blob: 3ba079bbc21bbf59e7ebcc7df15b6e285e347f2f [file] [log] [blame]
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
26#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
27
28#include "ckw/KernelWriter.h"
29#include "ckw/TensorOperand.h"
30#include "ckw/TileOperand.h"
31
32#include <iostream>
33#include <type_traits>
34
Nikolaj Jensenfab6c212023-06-27 14:13:24 +010035/*
36 * By including this header file you will be able to supplement the default
37 * Compute Kernel Writer API with additional syntax to help ease the use of CKW.
38 *
39 * To use the KernelWriterHelper you need to wrap your instance of KernelWriter
40 * (or any class deriving from KernelWriter):
41 * KernelWriterHelper<KernelWriter> writer;
42 * The resulting writer object comprises the original KernelWriter
43 * functionality (drop-in replacement), but extends the syntax as follows.
44 *
45 * Common functions/operators have natural syntax:
46 * 1. Unary expressions:
47 * writer.op_assign(dst, !src); // Logical NOT
48 * writer.op_assign(dst, ~src); // Bitwise NOT
49 *
50 * 2. Binary expressions:
51 * writer.op_assign(dst, lhs + rhs); // Addition
52 * writer.op_assign(dst, lhs - rhs); // Subtraction
53 * writer.op_assign(dst, lhs * rhs); // Multiplication
54 * writer.op_assign(dst, lhs / rhs); // Division
55 * writer.op_assign(dst, lhs % rhs); // Modulo
56 * writer.op_assign(dst, lhs == rhs); // Equality
57 * writer.op_assign(dst, lhs < rhs); // Less-than
58 * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal
59 * writer.op_assign(dst, lhs > rhs); // Greater-than
60 * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal
61 * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR
62 * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND
63 * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR
64 *
65 * 3. Unary elementwise functions:
66 * writer.op_assign(dst, exp(src)); // Exponent
67 * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent
68 * writer.op_assign(dst, sqrt(src)); // Square root
69 * writer.op_assign(dst, erf(src)); // Error function
 * writer.op_assign(dst, fabs(src)); // Absolute value of a floating-point number
71 * writer.op_assign(dst, log(src)); // Natural logarithm
72 * writer.op_assign(dst, round(src)); // Round
73 * writer.op_assign(dst, sizeOf(src)); // sizeof
74 *
75 * 4. Binary elementwise functions:
76 * writer.op_assign(dst, max(first, second)); // Max
77 * writer.op_assign(dst, min(first, second)); // Min
78 *
79 * 5. Ternary elementwise functions:
80 * writer.op_assign(dst, select(first, second, third)); // Select
81 *
82 * NOTE: All the above examples support nesting, so you could write
 * something like: writer.op_assign(dst, src * (log(arg) + sqrt(fabs(arg))));
84 *
85 *
86 * 6. If-statements. The preceding syntax also allows easier writing of if-statements:
87 * writer.op_if(<cond>, <body>);
88 *
89 * For example:
90 * writer.op_if(exp(first_arg) == dst, [&]{
91 * //...
92 * }).op_else_if(exp(first_arg) > dst, [&]{
93 * //...
94 * }).op_else([&] {
95 * //...
96 * });
97 *
98 * 7. For-loops. A similar syntax exists for for-loops:
99 * writer.op_for_loop(<cond>, <updater>, <body>);
100 *
101 * For example:
102 * writer.op_for_loop(index < limit, index += step, [&]{
103 * //...
104 * });
105 *
106 * NOTE: There are limitations on the for-loop <cond> and <updater> parameters.
 * Nesting is not allowed in either the <cond> (binary expression) or the <updater>
 * (increment/decrement) parameter. For example, `(index + other) < limit` and
109 * `index < round(limit)` are invalid <cond> parameters. This is because the
110 * semantics of for-loops rely on the condition being evaluated at every iteration,
111 * but as temporary variables might be defined for nested expressions the semantics
112 * cannot be guaranteed.
113 */
114
115namespace ckw
116{
117
118// ==================================================
119// Type traits
120// ==================================================
121
/** Trait answering whether \p T may be used as an operand of a function (e.g. max),
 * an operation (e.g. *), or an assignment.
 *
 * The primary template answers "no"; a type opts in by specialising this trait.
 *
 * @tparam T The type being queried.
 */
template <typename T>
struct can_be_operand : ::std::integral_constant<bool, false>
{
};
127
/** Trait answering whether \p T may be assigned/written to.
 *
 * The primary template answers "no"; a type opts in by specialising this trait.
 *
 * @tparam T The type being queried.
 */
template <typename T>
struct can_be_assigned : ::std::integral_constant<bool, false>
{
};
133
134template <>
135struct can_be_operand<TileOperand &> : ::std::true_type
136{
137};
138
139template <>
140struct can_be_assigned<TileOperand &> : ::std::true_type
141{
142};
143
144// ==================================================
145// Assignment
146// ==================================================
147
148/** AST node for assignments.
149 *
150 * Note that \p TRight must be an operand, and \p TLeft must be assignable.
151 *
152 * @tparam TLeft The type of the destination of the assignment.
153 * @tparam TRight The type of the source assigned to the destination.
154 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100155template <typename TLeft,
156 typename TRight,
157 typename = ::std::enable_if<can_be_operand<TRight>::value && can_be_assigned<TLeft>::value>>
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100158struct Assignment
159{
160 TLeft lhs;
161 TRight rhs;
162 AssignmentOp opcode;
163};
164
165/** Represents the expression: `\p lhs += \p rhs`.
166 *
167 * @tparam TLeft The type of the LHS of the assignment.
168 * @tparam TRight The type of the RHS of the assignment.
169 * @param[in] lhs The LHS of the assignment.
170 * @param[in] rhs The RHS of the assignment.
171 * @return The resulting AST node.
172 */
173template <typename TLeft, typename TRight>
174inline Assignment<TLeft, TRight> operator+=(TLeft &&lhs, TRight &&rhs)
175{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100176 return Assignment<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Increment};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100177}
178
179/** Represents the expression: `\p lhs -= \p rhs`.
180 *
181 * @tparam TLeft The type of the LHS of the assignment.
182 * @tparam TRight The type of the RHS of the assignment.
183 * @param[in] lhs The LHS of the assignment.
184 * @param[in] rhs The RHS of the assignment.
185 * @return The resulting AST node.
186 */
187template <typename TLeft, typename TRight>
188inline Assignment<TLeft, TRight> operator-=(TLeft &&lhs, TRight &&rhs)
189{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100190 return Assignment<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Decrement};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100191}
192
193// ==================================================
194// Unary expression
195// ==================================================
196
197/** AST node for unary expressions.
198 *
199 * Note that \p TSrc must be an operand.
200 *
201 * @tparam TSrc The type of the argument to the expression.
202 */
203template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
204struct UnaryExpression
205{
206 TSrc src;
207 UnaryOp opcode;
208};
209
210template <typename TLeft>
211struct can_be_operand<UnaryExpression<TLeft>> : ::std::true_type
212{
213};
214
215/** Represents the expression: `!\p src`.
216 *
217 * @tparam TSrc The type of the argument.
218 * @param[in] src The argument.
219 * @return The resulting AST node.
220 */
221template <typename TSrc>
222inline UnaryExpression<TSrc> operator!(TSrc &&src)
223{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100224 return UnaryExpression<TSrc>{std::forward<TSrc>(src), UnaryOp::LogicalNot};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100225}
226
227/** Represents the expression: `~\p src`.
228 *
229 * @tparam TSrc The type of the argument.
230 * @param[in] src The argument.
231 * @return The resulting AST node.
232 */
233template <typename TSrc>
234inline UnaryExpression<TSrc> operator~(TSrc &&src)
235{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100236 return UnaryExpression<TSrc>{std::forward<TSrc>(src), UnaryOp::BitwiseNot};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100237}
238
239// ==================================================
240// Binary expressions
241// ==================================================
242
243/** AST node for binary expressions.
244 *
245 * Note that both \p TLeft and \p TRight must be operands.
246 *
247 * @tparam TLeft The type of the left argument of the expression.
248 * @tparam TRight The type of the right argument of the expression.
249 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100250template <typename TLeft,
251 typename TRight,
252 typename = ::std::enable_if_t<can_be_operand<TLeft>::value && can_be_operand<TRight>::value>>
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100253struct BinaryExpression
254{
255 TLeft lhs;
256 TRight rhs;
257 BinaryOp opcode;
258};
259
260template <typename TLeft, typename TRight>
261struct can_be_operand<BinaryExpression<TLeft, TRight>> : ::std::true_type
262{
263};
264
265/** Represents the expression: `\p lhs + \p rhs`.
266 *
267 * @tparam TLeft The type of the LHS of the expression.
268 * @tparam TRight The type of the RHS of the expression.
269 * @param[in] lhs The LHS of the expression.
270 * @param[in] rhs The RHS of the expression.
271 * @return The resulting AST node.
272 */
273template <typename TLeft, typename TRight>
274inline BinaryExpression<TLeft, TRight> operator+(TLeft &&lhs, TRight &&rhs)
275{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100276 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Add};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100277}
278
279/** Represents the expression: `\p lhs - \p rhs`.
280 *
281 * @tparam TLeft The type of the LHS of the expression.
282 * @tparam TRight The type of the RHS of the expression.
283 * @param[in] lhs The LHS of the expression.
284 * @param[in] rhs The RHS of the expression.
285 * @return The resulting AST node.
286 */
287template <typename TLeft, typename TRight>
288inline BinaryExpression<TLeft, TRight> operator-(TLeft &&lhs, TRight &&rhs)
289{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100290 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Sub};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100291}
292
293/** Represents the expression: `\p lhs * \p rhs`.
294 *
295 * @tparam TLeft The type of the LHS of the expression.
296 * @tparam TRight The type of the RHS of the expression.
297 * @param[in] lhs The LHS of the expression.
298 * @param[in] rhs The RHS of the expression.
299 * @return The resulting AST node.
300 */
301template <typename TLeft, typename TRight>
302inline BinaryExpression<TLeft, TRight> operator*(TLeft &&lhs, TRight &&rhs)
303{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100304 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mul};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100305}
306
307/** Represents the expression: `\p lhs / \p rhs`.
308 *
309 * @tparam TLeft The type of the LHS of the expression.
310 * @tparam TRight The type of the RHS of the expression.
311 * @param[in] lhs The LHS of the expression.
312 * @param[in] rhs The RHS of the expression.
313 * @return The resulting AST node.
314 */
315template <typename TLeft, typename TRight>
316inline BinaryExpression<TLeft, TRight> operator/(TLeft &&lhs, TRight &&rhs)
317{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100318 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Div};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100319}
320
321/** Represents the expression: `\p lhs % \p rhs`.
322 *
323 * @tparam TLeft The type of the LHS of the expression.
324 * @tparam TRight The type of the RHS of the expression.
325 * @param[in] lhs The LHS of the expression.
326 * @param[in] rhs The RHS of the expression.
327 * @return The resulting AST node.
328 */
329template <typename TLeft, typename TRight>
330inline BinaryExpression<TLeft, TRight> operator%(TLeft &&lhs, TRight &&rhs)
331{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100332 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mod};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100333}
334
335/** Represents the expression: `\p lhs == \p rhs`.
336 *
337 * @tparam TLeft The type of the LHS of the expression.
338 * @tparam TRight The type of the RHS of the expression.
339 * @param[in] lhs The LHS of the expression.
340 * @param[in] rhs The RHS of the expression.
341 * @return The resulting AST node.
342 */
343template <typename TLeft, typename TRight>
344inline BinaryExpression<TLeft, TRight> operator==(TLeft &&lhs, TRight &&rhs)
345{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100346 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Equal};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100347}
348
349/** Represents the expression: `\p lhs < \p rhs`.
350 *
351 * @tparam TLeft The type of the LHS of the expression.
352 * @tparam TRight The type of the RHS of the expression.
353 * @param[in] lhs The LHS of the expression.
354 * @param[in] rhs The RHS of the expression.
355 * @return The resulting AST node.
356 */
357template <typename TLeft, typename TRight>
358inline BinaryExpression<TLeft, TRight> operator<(TLeft &&lhs, TRight &&rhs)
359{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100360 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Less};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100361}
362
363/** Represents the expression: `\p lhs <= \p rhs`.
364 *
365 * @tparam TLeft The type of the LHS of the expression.
366 * @tparam TRight The type of the RHS of the expression.
367 * @param[in] lhs The LHS of the expression.
368 * @param[in] rhs The RHS of the expression.
369 * @return The resulting AST node.
370 */
371template <typename TLeft, typename TRight>
372inline BinaryExpression<TLeft, TRight> operator<=(TLeft &&lhs, TRight &&rhs)
373{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100374 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LessEqual};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100375}
376
377/** Represents the expression: `\p lhs > \p rhs`.
378 *
379 * @tparam TLeft The type of the LHS of the expression.
380 * @tparam TRight The type of the RHS of the expression.
381 * @param[in] lhs The LHS of the expression.
382 * @param[in] rhs The RHS of the expression.
383 * @return The resulting AST node.
384 */
385template <typename TLeft, typename TRight>
386inline BinaryExpression<TLeft, TRight> operator>(TLeft &&lhs, TRight &&rhs)
387{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100388 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Greater};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100389}
390
391/** Represents the expression: `\p lhs >= \p rhs`.
392 *
393 * @tparam TLeft The type of the LHS of the expression.
394 * @tparam TRight The type of the RHS of the expression.
395 * @param[in] lhs The LHS of the expression.
396 * @param[in] rhs The RHS of the expression.
397 * @return The resulting AST node.
398 */
399template <typename TLeft, typename TRight>
400inline BinaryExpression<TLeft, TRight> operator>=(TLeft &&lhs, TRight &&rhs)
401{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100402 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::GreaterEqual};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100403}
404
405/** Represents the expression: `\p lhs ^ \p rhs`.
406 *
407 * @tparam TLeft The type of the LHS of the expression.
408 * @tparam TRight The type of the RHS of the expression.
409 * @param[in] lhs The LHS of the expression.
410 * @param[in] rhs The RHS of the expression.
411 * @return The resulting AST node.
412 */
413template <typename TLeft, typename TRight>
414inline BinaryExpression<TLeft, TRight> operator^(TLeft &&lhs, TRight &&rhs)
415{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100416 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::BitwiseXOR};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100417}
418
419/** Represents the expression: `\p lhs && \p rhs`.
420 *
421 * @tparam TLeft The type of the LHS of the expression.
422 * @tparam TRight The type of the RHS of the expression.
423 * @param[in] lhs The LHS of the expression.
424 * @param[in] rhs The RHS of the expression.
425 * @return The resulting AST node.
426 */
427template <typename TLeft, typename TRight>
428inline BinaryExpression<TLeft, TRight> logical_and(TLeft &&lhs, TRight &&rhs)
429{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100430 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100431}
432
433/** Represents the expression: `\p lhs && \p rhs`.
434 *
435 * @tparam TLeft The type of the LHS of the expression.
436 * @tparam TRight The type of the RHS of the expression.
437 * @param[in] lhs The LHS of the expression.
438 * @param[in] rhs The RHS of the expression.
439 * @return The resulting AST node.
440 */
441template <typename TLeft, typename TRight, typename... TOps>
442inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
443{
444 return logical_and(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100445 BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd},
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100446 std::forward<TOps>(ops)...);
447}
448
449/** Represents the expression: `\p lhs || \p rhs`.
450 *
451 * @tparam TLeft The type of the LHS of the expression.
452 * @tparam TRight The type of the RHS of the expression.
453 * @param[in] lhs The LHS of the expression.
454 * @param[in] rhs The RHS of the expression.
455 * @return The resulting AST node.
456 */
457template <typename TLeft, typename TRight>
458inline BinaryExpression<TLeft, TRight> logical_or(TLeft &&lhs, TRight &&rhs)
459{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100460 return BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100461}
462
463/** Represents the expression: `\p lhs || \p rhs`.
464 *
465 * @tparam TLeft The type of the LHS of the expression.
466 * @tparam TRight The type of the RHS of the expression.
467 * @param[in] lhs The LHS of the expression.
468 * @param[in] rhs The RHS of the expression.
469 * @return The resulting AST node.
470 */
471template <typename TLeft, typename TRight, typename... TOps>
472inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
473{
474 return logical_or(
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100475 BinaryExpression<TLeft, TRight>{std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr},
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100476 std::forward<TOps>(ops)...);
477}
478
479// ==================================================
480// Unary elementwise functions
481// ==================================================
482
483/** AST node for unary elementwise functions.
484 *
485 * Note that \p TSrc must be an operand.
486 *
487 * @tparam TSrc The type of the argument to the function.
488 */
489template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
490struct UnaryElementwiseFunction
491{
492 TSrc src;
493 UnaryFunction opcode;
494};
495
496template <typename TLeft>
497struct can_be_operand<UnaryElementwiseFunction<TLeft>> : ::std::true_type
498{
499};
500
501/** Represents the expression: `exp(\p src)`.
502 *
503 * @tparam TSrc The type of the argument.
504 * @param[in] src The argument.
505 * @return The resulting AST node.
506 */
507template <typename TSrc>
508UnaryElementwiseFunction<TSrc> exp(TSrc &&src)
509{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100510 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Exp};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100511}
512
513/** Represents the expression: `tanh(\p src)`.
514 *
515 * @tparam TSrc The type of the argument.
516 * @param[in] src The argument.
517 * @return The resulting AST node.
518 */
519template <typename TSrc>
520UnaryElementwiseFunction<TSrc> tanh(TSrc &&src)
521{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100522 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Tanh};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100523}
524
525/** Represents the expression: `sqrt(\p src)`.
526 *
527 * @tparam TSrc The type of the argument.
528 * @param[in] src The argument.
529 * @return The resulting AST node.
530 */
531template <typename TSrc>
532UnaryElementwiseFunction<TSrc> sqrt(TSrc &&src)
533{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100534 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Sqrt};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100535}
536
537/** Represents the expression: `erf(\p src)`.
538 *
539 * @tparam TSrc The type of the argument.
540 * @param[in] src The argument.
541 * @return The resulting AST node.
542 */
543template <typename TSrc>
544UnaryElementwiseFunction<TSrc> erf(TSrc &&src)
545{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100546 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Erf};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100547}
548
549/** Represents the expression: `fabs(\p src)`.
550 *
551 * @tparam TSrc The type of the argument.
552 * @param[in] src The argument.
553 * @return The resulting AST node.
554 */
555template <typename TSrc>
556UnaryElementwiseFunction<TSrc> fabs(TSrc &&src)
557{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100558 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Fabs};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100559}
560
561/** Represents the expression: `log(\p src)`.
562 *
563 * @tparam TSrc The type of the argument.
564 * @param[in] src The argument.
565 * @return The resulting AST node.
566 */
567template <typename TSrc>
568UnaryElementwiseFunction<TSrc> log(TSrc &&src)
569{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100570 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Log};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100571}
572
573/** Represents the expression: `round(\p src)`.
574 *
575 * @tparam TSrc The type of the argument.
576 * @param[in] src The argument.
577 * @return The resulting AST node.
578 */
579template <typename TSrc>
580UnaryElementwiseFunction<TSrc> round(TSrc &&src)
581{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100582 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::Round};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100583}
584
585/** Represents the expression: `sizeof(\p src)`.
586 *
587 * @tparam TSrc The type of the argument.
588 * @param[in] src The argument.
589 * @return The resulting AST node.
590 */
591template <typename TSrc>
592UnaryElementwiseFunction<TSrc> sizeOf(TSrc &&src)
593{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100594 return UnaryElementwiseFunction<TSrc>{std::forward<TSrc>(src), UnaryFunction::SizeOf};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100595}
596
597// ==================================================
598// Binary elementwise functions
599// ==================================================
600
601/** AST node for binary elementwise functions.
602 *
603 * Note that both \p TFirst and \p TSecond must be operands.
604 *
605 * @tparam TFirst The type of the left argument of the function.
606 * @tparam TSecond The type of the right argument of the function.
607 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100608template <typename TFirst,
609 typename TSecond,
610 typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value>>
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100611struct BinaryElementwiseFunction
612{
613 TFirst first;
614 TSecond second;
615 BinaryFunction opcode;
616};
617
618template <typename TFirst, typename TSecond>
619struct can_be_operand<BinaryElementwiseFunction<TFirst, TSecond>> : ::std::true_type
620{
621};
622
623/** Represents the function call: `max(\p first, \p second)`.
624 *
625 * @tparam TFirst The type of the first argument.
626 * @tparam TSecond The type of the second argument.
627 * @param[in] first The first argument.
628 * @param[in] second The second argument.
629 * @return The resulting AST node.
630 */
631template <typename TFirst, typename TSecond>
632BinaryElementwiseFunction<TFirst, TSecond> max(TFirst &&first, TSecond &&second)
633{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100634 return BinaryElementwiseFunction<TFirst, TSecond>{std::forward<TFirst>(first), std::forward<TSecond>(second),
635 BinaryFunction::Max};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100636}
637
638/** Represents the function call: `min(\p first, \p second)`.
639 *
640 * @tparam TFirst The type of the first argument.
641 * @tparam TSecond The type of the second argument.
642 * @param[in] first The first argument.
643 * @param[in] second The second argument.
644 * @return The resulting AST node.
645 */
646template <typename TFirst, typename TSecond>
647BinaryElementwiseFunction<TFirst, TSecond> min(TFirst &&first, TSecond &&second)
648{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100649 return BinaryElementwiseFunction<TFirst, TSecond>{std::forward<TFirst>(first), std::forward<TSecond>(second),
650 BinaryFunction::Min};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100651}
652
653// ==================================================
654// Ternary elementwise functions
655// ==================================================
656
657/** AST node for ternary elementwise functions.
658 *
659 * Note that \p TFirst, \p TSecond, and \p TThird all must be operands.
660 *
661 * @tparam TFirst The type of the first argument to the function.
662 * @tparam TSecond The type of the second argument to the function.
663 * @tparam TThird The type of the third argument to the function.
664 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100665template <typename TFirst,
666 typename TSecond,
667 typename TThird,
668 typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value &&
669 can_be_operand<TThird>::value>>
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100670struct TernaryElementwiseFunction
671{
672 TFirst first;
673 TSecond second;
674 TThird third;
675 TernaryFunction opcode;
676};
677
678template <typename TFirst, typename TSecond, typename TThird>
679struct can_be_operand<TernaryElementwiseFunction<TFirst, TSecond, TThird>> : ::std::true_type
680{
681};
682
683/** Represents the function call: `select(\p first, \p second, \p third)`.
684 *
685 * @tparam TFirst The type of the first argument.
686 * @tparam TSecond The type of the second argument.
687 * @tparam TThird The type of the third argument.
688 * @param[in] first The first argument.
689 * @param[in] second The second argument.
690 * @param[in] third The third argument.
691 * @return The resulting AST node.
692 */
693template <typename TFirst, typename TSecond, typename TThird>
694TernaryElementwiseFunction<TFirst, TSecond, TThird> select(TFirst &&first, TSecond &&second, TThird &&third)
695{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100696 return TernaryElementwiseFunction<TFirst, TSecond, TThird>{std::forward<TFirst>(first),
697 std::forward<TSecond>(second),
698 std::forward<TThird>(third), TernaryFunction::Select};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100699}
700
701/** Helper class used to extend a KernelWriter with additional functionality
702 * in order to make writing easier.
703 *
704 * This extension automatically handles creation of temporary variables, and
705 * allows nested function calls and operations.
706 *
707 * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter.
708 */
709template <class TWriter, typename = std::enable_if<std::is_base_of<KernelWriter, TWriter>::value>>
710class KernelWriterHelper : public TWriter
711{
712public:
713 using TWriter::TWriter;
714
715 // ==================================================
716 // If-statements
717 // ==================================================
718
719 // Un-hide original implementation, in case the original implementation is required.
720 using TWriter::op_if;
721
722 /** Represents the if-statement: `if(\p cond) { \p body }`.
723 *
724 * The BinaryExpression is unpacked and its components are forwarded to
725 * the underlying KernelWriter's implementation.
726 *
727 * @param[in] cond The BinaryExpression representing the condition.
728 * @param[in] body The body of the if-statement.
729 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100730 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TileOperand &> &cond,
731 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100732 {
733 TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body);
734 return *this;
735 }
736
737 /** Represents the if-statement: `if(\p cond) { \p body }`.
738 *
739 * The BinaryExpression is unpacked and its components are forwarded to
740 * the underlying KernelWriter's implementation.
741 *
742 * @param[in] cond The BinaryExpression representing the condition.
743 * @param[in] body The body of the if-statement.
744 */
745 template <typename TRight>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100746 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TRight> &cond,
747 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100748 {
749 auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
750 op_assign(tmp1, cond.rhs);
751 TWriter::op_if(cond.lhs, cond.opcode, tmp1, body);
752 return *this;
753 }
754
755 /** Represents the if-statement: `if(\p cond) { \p body }`.
756 *
757 * The BinaryExpression is unpacked and its components are forwarded to
758 * the underlying KernelWriter's implementation.
759 *
760 * @param[in] cond The BinaryExpression representing the condition.
761 * @param[in] body The body of the if-statement.
762 */
763 template <typename TLeft>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100764 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TLeft, TileOperand &> &cond,
765 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100766 {
767 auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
768 op_assign(tmp1, cond.lhs);
769 TWriter::op_if(tmp1, cond.opcode, cond.rhs, body);
770 return *this;
771 }
772
773 // Un-hide original implementation, in case the original implementation is required.
774 using TWriter::op_else_if;
775
776 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
777 *
778 * The BinaryExpression is unpacked and its components are forwarded to
779 * the underlying KernelWriter's implementation.
780 *
781 * @param[in] cond The BinaryExpression representing the condition.
782 * @param[in] body The body of the else-if-statement.
783 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100784 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TileOperand &> &cond,
785 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100786 {
787 TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body);
788 return *this;
789 }
790
791 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
792 *
793 * The BinaryExpression is unpacked and its components are forwarded to
794 * the underlying KernelWriter's implementation.
795 *
796 * @param[in] cond The BinaryExpression representing the condition.
797 * @param[in] body The body of the else-if-statement.
798 */
799 template <typename TRight>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100800 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TRight> &cond,
801 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100802 {
803 auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
804 op_assign(tmp1, cond.rhs);
805 TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body);
806 return *this;
807 }
808
809 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
810 *
811 * The BinaryExpression is unpacked and its components are forwarded to
812 * the underlying KernelWriter's implementation.
813 *
814 * @param[in] cond The BinaryExpression representing the condition.
815 * @param[in] body The body of the else-if-statement.
816 */
817 template <typename TLeft>
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100818 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TLeft, TileOperand &> &cond,
819 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100820 {
821 auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
822 op_assign(tmp1, cond.lhs);
823 TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body);
824 return *this;
825 }
826
827 // ==================================================
828 // For-loops
829 // ==================================================
830
831 // Un-hide original implementation, in case the original implementation is required.
832 using TWriter::op_for_loop;
833
834 /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`.
835 *
836 * The BinaryExpression for the condition and the Assignment
837 * for the updater are unpacked and their components are forwarded to
838 * the underlying KernelWriter's implementation.
839 *
840 * @param[in] cond The BinaryExpression representing the condition.
841 * @param[in] updater The Assignment representing the updater.
842 * @param[in] body The body of the for-loop.
843 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100844 void op_for_loop(const BinaryExpression<TileOperand &, TileOperand &> &cond,
845 const Assignment<TileOperand &, TileOperand &> &updater,
846 const std::function<void()> &body)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +0100847 {
848 TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body);
849 }
850
851 // ==================================================
852 // Unary expressions
853 // ==================================================
854
855 // Un-hide original implementation, in case the original implementation is required.
856 using TWriter::op_assign;
857
858 /** Represents the assignment: `\p dst = \p exp`.
859 *
860 * The UnaryExpression is unpacked and its components are forwarded to
861 * the underlying KernelWriter's implementation.
862 *
863 * @param[in] dst The tile which is assigned to.
864 * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned.
865 */
866 void op_assign(const TileOperand &dst, const UnaryExpression<TileOperand &> &exp)
867 {
868 TWriter::op_unary_expression(dst, exp.opcode, exp.src);
869 }
870
871 // ==================================================
872 // Binary expressions
873 // ==================================================
874
875 /** Represents the assignment: `\p dst = \p exp`.
876 *
877 * The BinaryExpression is unpacked and its components are forwarded to
878 * the underlying KernelWriter's implementation.
879 *
880 * @param[in] dst The tile which is assigned to.
881 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
882 */
883 void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TileOperand &> &exp)
884 {
885 TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs);
886 }
887
888 /** Represents the assignment: `\p dst = \p exp`.
889 *
890 * The BinaryExpression is unpacked and its components are forwarded to
891 * the underlying KernelWriter's implementation.
892 *
893 * @param[in] dst The tile which is assigned to.
894 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
895 */
896 template <typename TRight>
897 void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TRight> &exp)
898 {
899 std::cout << "Beginning assignment!" << std::endl;
900 auto &tmp1 = declare_temp_tile(dst.tile_info());
901 op_assign(tmp1, exp.rhs);
902 TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1);
903 }
904
905 /** Represents the assignment: `\p dst = \p exp`.
906 *
907 * The BinaryExpression is unpacked and its components are forwarded to
908 * the underlying KernelWriter's implementation.
909 *
910 * @param[in] dst The tile which is assigned to.
911 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
912 */
913 template <typename TLeft>
914 void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TileOperand &> &exp)
915 {
916 std::cout << "Beginning assignment!" << std::endl;
917 auto &tmp1 = declare_temp_tile(dst.tile_info());
918 op_assign(tmp1, exp.lhs);
919 TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs);
920 }
921
922 /** Represents the assignment: `\p dst = \p exp`.
923 *
924 * The BinaryExpression is unpacked and its components are forwarded to
925 * the underlying KernelWriter's implementation.
926 *
927 * @param[in] dst The tile which is assigned to.
928 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
929 */
930 template <typename TLeft, typename TRight>
931 void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TRight> &exp)
932 {
933 auto &tmp1 = declare_temp_tile(dst.tile_info());
934 auto &tmp2 = declare_temp_tile(dst.tile_info());
935 op_assign(tmp1, exp.lhs);
936 op_assign(tmp2, exp.rhs);
937 TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2);
938 }
939
940 // ==================================================
941 // Unary elementwise functions
942 // ==================================================
943
944 /** Represents the assignment: `\p dst = \p exp`.
945 *
946 * The UnaryElementwiseFunction is unpacked and its components are forwarded to
947 * the underlying KernelWriter's implementation.
948 *
949 * @param[in] dst The tile which is assigned to.
950 * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
951 */
952 void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TileOperand &> &exp)
953 {
954 TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src);
955 }
956
957 /** Represents the assignment: `\p dst = \p exp`.
958 *
959 * The UnaryElementwiseFunction is unpacked and its components are forwarded to
960 * the underlying KernelWriter's implementation.
961 *
962 * @param[in] dst The tile which is assigned to.
963 * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
964 */
965 template <typename TArg>
966 void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TArg> &exp)
967 {
968 auto &tmp1 = declare_temp_tile(dst.tile_info());
969 op_assign(tmp1, exp.lhs);
970 TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1);
971 }
972
973 // ==================================================
974 // Binary elementwise functions
975 // ==================================================
976
977 /** Represents the assignment: `\p dst = \p exp`.
978 *
979 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
980 * the underlying KernelWriter's implementation.
981 *
982 * @param[in] dst The tile which is assigned to.
983 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
984 */
985 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TileOperand &> &exp)
986 {
987 TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second);
988 }
989
990 /** Represents the assignment: `\p dst = \p exp`.
991 *
992 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
993 * the underlying KernelWriter's implementation.
994 *
995 * @param[in] dst The tile which is assigned to.
996 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
997 */
998 template <typename TRight>
999 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TRight> &exp)
1000 {
1001 auto &tmp1 = declare_temp_tile(dst.tile_info());
1002 op_assign(tmp1, exp.second);
1003 TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1);
1004 }
1005
1006 /** Represents the assignment: `\p dst = \p exp`.
1007 *
1008 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
1009 * the underlying KernelWriter's implementation.
1010 *
1011 * @param[in] dst The tile which is assigned to.
1012 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
1013 */
1014 template <typename TLeft>
1015 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TileOperand &> &exp)
1016 {
1017 auto &tmp1 = declare_temp_tile(dst.tile_info());
1018 op_assign(tmp1, exp.first);
1019 TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second);
1020 }
1021
1022 /** Represents the assignment: `\p dst = \p exp`.
1023 *
1024 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
1025 * the underlying KernelWriter's implementation.
1026 *
1027 * @param[in] dst The tile which is assigned to.
1028 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
1029 */
1030 template <typename TLeft, typename TRight>
1031 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TRight> &exp)
1032 {
1033 auto &tmp1 = declare_temp_tile(dst.tile_info());
1034 auto &tmp2 = declare_temp_tile(dst.tile_info());
1035 op_assign(tmp1, exp.first);
1036 op_assign(tmp2, exp.second);
1037 TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2);
1038 }
1039
1040 // ==================================================
1041 // Ternary elementwise functions
1042 // ==================================================
1043
1044 /** Represents the assignment: `\p dst = \p exp`.
1045 *
1046 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1047 * the underlying KernelWriter's implementation.
1048 *
1049 * @param[in] dst The tile which is assigned to.
1050 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1051 */
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001052 void op_assign(const TileOperand &dst,
1053 const TernaryElementwiseFunction<TileOperand &, TileOperand &, TileOperand &> &exp)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001054 {
1055 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third);
1056 }
1057
1058 /** Represents the assignment: `\p dst = \p exp`.
1059 *
1060 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1061 * the underlying KernelWriter's implementation.
1062 *
1063 * @param[in] dst The tile which is assigned to.
1064 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1065 */
1066 template <typename TFirst>
1067 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TileOperand &> &exp)
1068 {
1069 auto &tmp1 = declare_temp_tile(dst.tile_info());
1070 op_assign(tmp1, exp.first);
1071 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third);
1072 }
1073
1074 /** Represents the assignment: `\p dst = \p exp`.
1075 *
1076 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1077 * the underlying KernelWriter's implementation.
1078 *
1079 * @param[in] dst The tile which is assigned to.
1080 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1081 */
1082 template <typename TSecond>
1083 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TileOperand &> &exp)
1084 {
1085 auto &tmp1 = declare_temp_tile(dst.tile_info());
1086 op_assign(tmp1, exp.second);
1087 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third);
1088 }
1089
1090 /** Represents the assignment: `\p dst = \p exp`.
1091 *
1092 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1093 * the underlying KernelWriter's implementation.
1094 *
1095 * @param[in] dst The tile which is assigned to.
1096 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1097 */
1098 template <typename TThird>
1099 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TThird> &exp)
1100 {
1101 auto &tmp1 = declare_temp_tile(dst.tile_info());
1102 op_assign(tmp1, exp.third);
1103 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1);
1104 }
1105
1106 /** Represents the assignment: `\p dst = \p exp`.
1107 *
1108 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1109 * the underlying KernelWriter's implementation.
1110 *
1111 * @param[in] dst The tile which is assigned to.
1112 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1113 */
1114 template <typename TFirst, typename TSecond>
1115 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TileOperand &> &exp)
1116 {
1117 auto &tmp1 = declare_temp_tile(dst.tile_info());
1118 auto &tmp2 = declare_temp_tile(dst.tile_info());
1119 op_assign(tmp1, exp.first);
1120 op_assign(tmp2, exp.second);
1121 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third);
1122 }
1123
1124 /** Represents the assignment: `\p dst = \p exp`.
1125 *
1126 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1127 * the underlying KernelWriter's implementation.
1128 *
1129 * @param[in] dst The tile which is assigned to.
1130 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1131 */
1132 template <typename TFirst, typename TThird>
1133 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TThird> &exp)
1134 {
1135 auto &tmp1 = declare_temp_tile(dst.tile_info());
1136 auto &tmp2 = declare_temp_tile(dst.tile_info());
1137 op_assign(tmp1, exp.first);
1138 op_assign(tmp2, exp.third);
1139 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2);
1140 }
1141
1142 /** Represents the assignment: `\p dst = \p exp`.
1143 *
1144 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1145 * the underlying KernelWriter's implementation.
1146 *
1147 * @param[in] dst The tile which is assigned to.
1148 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1149 */
1150 template <typename TSecond, typename TThird>
1151 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TThird> &exp)
1152 {
1153 auto &tmp1 = declare_temp_tile(dst.tile_info());
1154 auto &tmp2 = declare_temp_tile(dst.tile_info());
1155 op_assign(tmp1, exp.second);
1156 op_assign(tmp2, exp.third);
1157 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2);
1158 }
1159
1160 /** Represents the assignment: `\p dst = \p exp`.
1161 *
1162 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1163 * the underlying KernelWriter's implementation.
1164 *
1165 * @param[in] dst The tile which is assigned to.
1166 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1167 */
1168 template <typename TFirst, typename TSecond, typename TThird>
1169 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TThird> &exp)
1170 {
1171 auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info());
1172 auto &tmp2 = declare_temp_tile(dst.tile_info());
1173 auto &tmp3 = declare_temp_tile(dst.tile_info());
1174 op_assign(tmp1, exp.first);
1175 op_assign(tmp2, exp.second);
1176 op_assign(tmp3, exp.third);
1177 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3);
1178 }
1179
1180 // ==================================================
1181 // Assignments
1182 // ==================================================
1183
1184 /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
1185 *
1186 * The Assignment is unpacked and its components are forwarded to
1187 * the underlying KernelWriter's implementation.
1188 *
1189 * @param[in] exp The Assignment representing the expression to be evaluated.
1190 */
1191 void op_assign(const Assignment<TileOperand &, TileOperand &> &exp)
1192 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001193 if (exp.opcode == AssignmentOp::Increment)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001194 {
1195 TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs);
1196 }
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001197 else if (exp.opcode == AssignmentOp::Decrement)
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001198 {
1199 TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs);
1200 }
1201 }
1202
1203 /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
1204 *
1205 * The Assignment is unpacked and its components are forwarded to
1206 * the underlying KernelWriter's implementation.
1207 *
1208 * @tparam TRight The type of the RHS of the assignment.
1209 * @param[in] exp The Assignment representing the expression to be evaluated.
1210 */
1211 template <typename TRight>
1212 void op_assign(const Assignment<TileOperand &, TRight> &exp)
1213 {
1214 auto &tmp1 = declare_temp_tile(exp.lhs.tile_info());
1215 op_assign(tmp1, exp.rhs);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001216 op_assign(Assignment<TileOperand &, TileOperand &>{exp.lhs, tmp1, exp.opcode});
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001217 }
1218
1219private:
/** Number of temporary variable names handed out so far. */
unsigned int temp_var_counter = 0;

/** Return the current counter value, then increment it.
 *
 * The value is returned with the same unsigned type as the counter itself,
 * so no signed conversion takes place.
 *
 * @return The counter value before the increment.
 */
unsigned int next_ctr()
{
    return temp_var_counter++;
}
1230
1231 /** Gets the next temporary variable counter value,
1232 * and returns a suitable temporary variable name.
1233 *
1234 * @return A temporary variable name.
1235 */
1236 std::string next_tmp_var_name()
1237 {
1238 return "tmp_" + std::to_string(next_ctr());
1239 }
1240
1241 /** Returns the argument.
1242 *
1243 * Used for recursion with the variadic function version of this function.
1244 *
1245 * @param[in] arg The TileInfo to return.
1246 * @return The \p arg.
1247 */
1248 TileInfo get_largest_size(const TileInfo &arg)
1249 {
1250 return arg;
1251 }
1252
1253 /** Returns a TileInfo object where the size in each dimension (width, height) is the largest
1254 * of either TileInfo argument in the corresponding dimension.
1255 *
1256 * @tparam TOps Must be of TileInfo type.
1257 * @param[in] first A TileInfo object.
1258 * @param[in] second A TileInfo object.
1259 * @param[in] ops A number of TileInfo objects.
1260 * @return A TileInfo object which represents the largest shape in each dimension across the arguments.
1261 */
1262 template <typename... TOps, typename = ::std::enable_if_t<std::is_same<TOps..., TileInfo>::value>>
1263 TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops)
1264 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +01001265 TileInfo largest = {first.data_type(), std::max(first.width(), second.width()),
1266 std::max(first.height(), second.height())};
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001267 return get_largest_size(largest, ops...);
1268 }
1269
1270 /** Helper function to define a suitable TileOperand with appropriate TileInfo
1271 * such that broadcasting is taken into account, based on the arguments provided.
1272 *
1273 * @tparam TArgs Must be of TileInfo type.
1274 * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare.
1275 * @return A newly created TileOperand.
1276 */
1277 template <typename... TArgs, typename = ::std::enable_if_t<std::is_same<TArgs..., TileInfo>::value>>
1278 TileOperand &declare_temp_tile(const TArgs &...args)
1279 {
1280 return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...));
1281 }
1282};
1283
1284} // namespace ckw
1285
1286#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H