blob: a8be859680ca03885c1c428cda1df1515678a65b [file] [log] [blame]
Nikolaj Jensenfab6c212023-06-27 14:13:24 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
26#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
27
28#include "ckw/KernelWriter.h"
29#include "ckw/TensorOperand.h"
30#include "ckw/TileOperand.h"
31
32#include <iostream>
33#include <type_traits>
34
35#include <iostream>
36
37/*
38 * By including this header file you will be able to supplement the default
39 * Compute Kernel Writer API with additional syntax to help ease the use of CKW.
40 *
41 * To use the KernelWriterHelper you need to wrap your instance of KernelWriter
42 * (or any class deriving from KernelWriter):
43 * KernelWriterHelper<KernelWriter> writer;
44 * The resulting writer object comprises the original KernelWriter
45 * functionality (drop-in replacement), but extends the syntax as follows.
46 *
47 * Common functions/operators have natural syntax:
48 * 1. Unary expressions:
49 * writer.op_assign(dst, !src); // Logical NOT
50 * writer.op_assign(dst, ~src); // Bitwise NOT
51 *
52 * 2. Binary expressions:
53 * writer.op_assign(dst, lhs + rhs); // Addition
54 * writer.op_assign(dst, lhs - rhs); // Subtraction
55 * writer.op_assign(dst, lhs * rhs); // Multiplication
56 * writer.op_assign(dst, lhs / rhs); // Division
57 * writer.op_assign(dst, lhs % rhs); // Modulo
58 * writer.op_assign(dst, lhs == rhs); // Equality
59 * writer.op_assign(dst, lhs < rhs); // Less-than
60 * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal
61 * writer.op_assign(dst, lhs > rhs); // Greater-than
62 * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal
63 * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR
64 * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND
65 * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR
66 *
67 * 3. Unary elementwise functions:
68 * writer.op_assign(dst, exp(src)); // Exponent
69 * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent
70 * writer.op_assign(dst, sqrt(src)); // Square root
71 * writer.op_assign(dst, erf(src)); // Error function
72 * writer.op_assign(dst, fabs(src)); // Absolute of floating-point number
73 * writer.op_assign(dst, log(src)); // Natural logarithm
74 * writer.op_assign(dst, round(src)); // Round
75 * writer.op_assign(dst, sizeOf(src)); // sizeof
76 *
77 * 4. Binary elementwise functions:
78 * writer.op_assign(dst, max(first, second)); // Max
79 * writer.op_assign(dst, min(first, second)); // Min
80 *
81 * 5. Ternary elementwise functions:
82 * writer.op_assign(dst, select(first, second, third)); // Select
83 *
 * NOTE: All the above examples support nesting, so you could write
 * something like: writer.op_assign(dst, src * (log(arg) + sqrt(fabs(arg))));
86 *
87 *
88 * 6. If-statements. The preceding syntax also allows easier writing of if-statements:
89 * writer.op_if(<cond>, <body>);
90 *
91 * For example:
92 * writer.op_if(exp(first_arg) == dst, [&]{
93 * //...
94 * }).op_else_if(exp(first_arg) > dst, [&]{
95 * //...
96 * }).op_else([&] {
97 * //...
98 * });
99 *
100 * 7. For-loops. A similar syntax exists for for-loops:
101 * writer.op_for_loop(<cond>, <updater>, <body>);
102 *
103 * For example:
104 * writer.op_for_loop(index < limit, index += step, [&]{
105 * //...
106 * });
107 *
108 * NOTE: There are limitations on the for-loop <cond> and <updater> parameters.
 * Nesting is not allowed in either the <cond> (binary expression) or the <updater>
 * (increment/decrement). For example, `(index + other) < limit` and
111 * `index < round(limit)` are invalid <cond> parameters. This is because the
112 * semantics of for-loops rely on the condition being evaluated at every iteration,
113 * but as temporary variables might be defined for nested expressions the semantics
114 * cannot be guaranteed.
115 */
116
117namespace ckw
118{
119
120// ==================================================
121// Type traits
122// ==================================================
123
/** Trait telling whether a type may be used as an operand of functions (e.g. max),
 * operations (e.g. *), or assignments.
 *
 * The primary template answers "no" for every type; supported types provide a
 * specialization deriving from ::std::true_type.
 *
 * @tparam TCandidate The type being queried.
 */
template <typename TCandidate>
struct can_be_operand : ::std::false_type
{
};
129
/** Trait telling whether a type may be the destination of an assignment (i.e. can be written to).
 *
 * The primary template answers "no" for every type; supported types provide a
 * specialization deriving from ::std::true_type.
 *
 * @tparam TCandidate The type being queried.
 */
template <typename TCandidate>
struct can_be_assigned : ::std::false_type
{
};
135
136template <>
137struct can_be_operand<TileOperand &> : ::std::true_type
138{
139};
140
141template <>
142struct can_be_assigned<TileOperand &> : ::std::true_type
143{
144};
145
146// ==================================================
147// Assignment
148// ==================================================
149
150/** AST node for assignments.
151 *
152 * Note that \p TRight must be an operand, and \p TLeft must be assignable.
153 *
154 * @tparam TLeft The type of the destination of the assignment.
155 * @tparam TRight The type of the source assigned to the destination.
156 */
157template <typename TLeft, typename TRight, typename = ::std::enable_if<can_be_operand<TRight>::value && can_be_assigned<TLeft>::value>>
158struct Assignment
159{
160 TLeft lhs;
161 TRight rhs;
162 AssignmentOp opcode;
163};
164
165/** Represents the expression: `\p lhs += \p rhs`.
166 *
167 * @tparam TLeft The type of the LHS of the assignment.
168 * @tparam TRight The type of the RHS of the assignment.
169 * @param[in] lhs The LHS of the assignment.
170 * @param[in] rhs The RHS of the assignment.
171 * @return The resulting AST node.
172 */
173template <typename TLeft, typename TRight>
174inline Assignment<TLeft, TRight> operator+=(TLeft &&lhs, TRight &&rhs)
175{
176 return Assignment<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Increment };
177}
178
179/** Represents the expression: `\p lhs -= \p rhs`.
180 *
181 * @tparam TLeft The type of the LHS of the assignment.
182 * @tparam TRight The type of the RHS of the assignment.
183 * @param[in] lhs The LHS of the assignment.
184 * @param[in] rhs The RHS of the assignment.
185 * @return The resulting AST node.
186 */
187template <typename TLeft, typename TRight>
188inline Assignment<TLeft, TRight> operator-=(TLeft &&lhs, TRight &&rhs)
189{
190 return Assignment<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Decrement };
191}
192
193// ==================================================
194// Unary expression
195// ==================================================
196
197/** AST node for unary expressions.
198 *
199 * Note that \p TSrc must be an operand.
200 *
201 * @tparam TSrc The type of the argument to the expression.
202 */
203template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
204struct UnaryExpression
205{
206 TSrc src;
207 UnaryOp opcode;
208};
209
210template <typename TLeft>
211struct can_be_operand<UnaryExpression<TLeft>> : ::std::true_type
212{
213};
214
215/** Represents the expression: `!\p src`.
216 *
217 * @tparam TSrc The type of the argument.
218 * @param[in] src The argument.
219 * @return The resulting AST node.
220 */
221template <typename TSrc>
222inline UnaryExpression<TSrc> operator!(TSrc &&src)
223{
224 return UnaryExpression<TSrc>{ std::forward<TSrc>(src), UnaryOp::LogicalNot };
225}
226
227/** Represents the expression: `~\p src`.
228 *
229 * @tparam TSrc The type of the argument.
230 * @param[in] src The argument.
231 * @return The resulting AST node.
232 */
233template <typename TSrc>
234inline UnaryExpression<TSrc> operator~(TSrc &&src)
235{
236 return UnaryExpression<TSrc>{ std::forward<TSrc>(src), UnaryOp::BitwiseNot };
237}
238
239// ==================================================
240// Binary expressions
241// ==================================================
242
243/** AST node for binary expressions.
244 *
245 * Note that both \p TLeft and \p TRight must be operands.
246 *
247 * @tparam TLeft The type of the left argument of the expression.
248 * @tparam TRight The type of the right argument of the expression.
249 */
250template <typename TLeft, typename TRight, typename = ::std::enable_if_t<can_be_operand<TLeft>::value && can_be_operand<TRight>::value>>
251struct BinaryExpression
252{
253 TLeft lhs;
254 TRight rhs;
255 BinaryOp opcode;
256};
257
258template <typename TLeft, typename TRight>
259struct can_be_operand<BinaryExpression<TLeft, TRight>> : ::std::true_type
260{
261};
262
263/** Represents the expression: `\p lhs + \p rhs`.
264 *
265 * @tparam TLeft The type of the LHS of the expression.
266 * @tparam TRight The type of the RHS of the expression.
267 * @param[in] lhs The LHS of the expression.
268 * @param[in] rhs The RHS of the expression.
269 * @return The resulting AST node.
270 */
271template <typename TLeft, typename TRight>
272inline BinaryExpression<TLeft, TRight> operator+(TLeft &&lhs, TRight &&rhs)
273{
274 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Add };
275}
276
277/** Represents the expression: `\p lhs - \p rhs`.
278 *
279 * @tparam TLeft The type of the LHS of the expression.
280 * @tparam TRight The type of the RHS of the expression.
281 * @param[in] lhs The LHS of the expression.
282 * @param[in] rhs The RHS of the expression.
283 * @return The resulting AST node.
284 */
285template <typename TLeft, typename TRight>
286inline BinaryExpression<TLeft, TRight> operator-(TLeft &&lhs, TRight &&rhs)
287{
288 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Sub };
289}
290
291/** Represents the expression: `\p lhs * \p rhs`.
292 *
293 * @tparam TLeft The type of the LHS of the expression.
294 * @tparam TRight The type of the RHS of the expression.
295 * @param[in] lhs The LHS of the expression.
296 * @param[in] rhs The RHS of the expression.
297 * @return The resulting AST node.
298 */
299template <typename TLeft, typename TRight>
300inline BinaryExpression<TLeft, TRight> operator*(TLeft &&lhs, TRight &&rhs)
301{
302 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mul };
303}
304
305/** Represents the expression: `\p lhs / \p rhs`.
306 *
307 * @tparam TLeft The type of the LHS of the expression.
308 * @tparam TRight The type of the RHS of the expression.
309 * @param[in] lhs The LHS of the expression.
310 * @param[in] rhs The RHS of the expression.
311 * @return The resulting AST node.
312 */
313template <typename TLeft, typename TRight>
314inline BinaryExpression<TLeft, TRight> operator/(TLeft &&lhs, TRight &&rhs)
315{
316 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Div };
317}
318
319/** Represents the expression: `\p lhs % \p rhs`.
320 *
321 * @tparam TLeft The type of the LHS of the expression.
322 * @tparam TRight The type of the RHS of the expression.
323 * @param[in] lhs The LHS of the expression.
324 * @param[in] rhs The RHS of the expression.
325 * @return The resulting AST node.
326 */
327template <typename TLeft, typename TRight>
328inline BinaryExpression<TLeft, TRight> operator%(TLeft &&lhs, TRight &&rhs)
329{
330 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mod };
331}
332
333/** Represents the expression: `\p lhs == \p rhs`.
334 *
335 * @tparam TLeft The type of the LHS of the expression.
336 * @tparam TRight The type of the RHS of the expression.
337 * @param[in] lhs The LHS of the expression.
338 * @param[in] rhs The RHS of the expression.
339 * @return The resulting AST node.
340 */
341template <typename TLeft, typename TRight>
342inline BinaryExpression<TLeft, TRight> operator==(TLeft &&lhs, TRight &&rhs)
343{
344 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Equal };
345}
346
347/** Represents the expression: `\p lhs < \p rhs`.
348 *
349 * @tparam TLeft The type of the LHS of the expression.
350 * @tparam TRight The type of the RHS of the expression.
351 * @param[in] lhs The LHS of the expression.
352 * @param[in] rhs The RHS of the expression.
353 * @return The resulting AST node.
354 */
355template <typename TLeft, typename TRight>
356inline BinaryExpression<TLeft, TRight> operator<(TLeft &&lhs, TRight &&rhs)
357{
358 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Less };
359}
360
361/** Represents the expression: `\p lhs <= \p rhs`.
362 *
363 * @tparam TLeft The type of the LHS of the expression.
364 * @tparam TRight The type of the RHS of the expression.
365 * @param[in] lhs The LHS of the expression.
366 * @param[in] rhs The RHS of the expression.
367 * @return The resulting AST node.
368 */
369template <typename TLeft, typename TRight>
370inline BinaryExpression<TLeft, TRight> operator<=(TLeft &&lhs, TRight &&rhs)
371{
372 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LessEqual };
373}
374
375/** Represents the expression: `\p lhs > \p rhs`.
376 *
377 * @tparam TLeft The type of the LHS of the expression.
378 * @tparam TRight The type of the RHS of the expression.
379 * @param[in] lhs The LHS of the expression.
380 * @param[in] rhs The RHS of the expression.
381 * @return The resulting AST node.
382 */
383template <typename TLeft, typename TRight>
384inline BinaryExpression<TLeft, TRight> operator>(TLeft &&lhs, TRight &&rhs)
385{
386 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Greater };
387}
388
389/** Represents the expression: `\p lhs >= \p rhs`.
390 *
391 * @tparam TLeft The type of the LHS of the expression.
392 * @tparam TRight The type of the RHS of the expression.
393 * @param[in] lhs The LHS of the expression.
394 * @param[in] rhs The RHS of the expression.
395 * @return The resulting AST node.
396 */
397template <typename TLeft, typename TRight>
398inline BinaryExpression<TLeft, TRight> operator>=(TLeft &&lhs, TRight &&rhs)
399{
400 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::GreaterEqual };
401}
402
403/** Represents the expression: `\p lhs ^ \p rhs`.
404 *
405 * @tparam TLeft The type of the LHS of the expression.
406 * @tparam TRight The type of the RHS of the expression.
407 * @param[in] lhs The LHS of the expression.
408 * @param[in] rhs The RHS of the expression.
409 * @return The resulting AST node.
410 */
411template <typename TLeft, typename TRight>
412inline BinaryExpression<TLeft, TRight> operator^(TLeft &&lhs, TRight &&rhs)
413{
414 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::BitwiseXOR };
415}
416
417/** Represents the expression: `\p lhs && \p rhs`.
418 *
419 * @tparam TLeft The type of the LHS of the expression.
420 * @tparam TRight The type of the RHS of the expression.
421 * @param[in] lhs The LHS of the expression.
422 * @param[in] rhs The RHS of the expression.
423 * @return The resulting AST node.
424 */
425template <typename TLeft, typename TRight>
426inline BinaryExpression<TLeft, TRight> logical_and(TLeft &&lhs, TRight &&rhs)
427{
428 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd };
429}
430
431/** Represents the expression: `\p lhs && \p rhs`.
432 *
433 * @tparam TLeft The type of the LHS of the expression.
434 * @tparam TRight The type of the RHS of the expression.
435 * @param[in] lhs The LHS of the expression.
436 * @param[in] rhs The RHS of the expression.
437 * @return The resulting AST node.
438 */
439template <typename TLeft, typename TRight, typename... TOps>
440inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
441{
442 return logical_and(
443 BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd },
444 std::forward<TOps>(ops)...);
445}
446
447/** Represents the expression: `\p lhs || \p rhs`.
448 *
449 * @tparam TLeft The type of the LHS of the expression.
450 * @tparam TRight The type of the RHS of the expression.
451 * @param[in] lhs The LHS of the expression.
452 * @param[in] rhs The RHS of the expression.
453 * @return The resulting AST node.
454 */
455template <typename TLeft, typename TRight>
456inline BinaryExpression<TLeft, TRight> logical_or(TLeft &&lhs, TRight &&rhs)
457{
458 return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr };
459}
460
461/** Represents the expression: `\p lhs || \p rhs`.
462 *
463 * @tparam TLeft The type of the LHS of the expression.
464 * @tparam TRight The type of the RHS of the expression.
465 * @param[in] lhs The LHS of the expression.
466 * @param[in] rhs The RHS of the expression.
467 * @return The resulting AST node.
468 */
469template <typename TLeft, typename TRight, typename... TOps>
470inline BinaryExpression<BinaryExpression<TLeft, TRight>, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
471{
472 return logical_or(
473 BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr },
474 std::forward<TOps>(ops)...);
475}
476
477// ==================================================
478// Unary elementwise functions
479// ==================================================
480
481/** AST node for unary elementwise functions.
482 *
483 * Note that \p TSrc must be an operand.
484 *
485 * @tparam TSrc The type of the argument to the function.
486 */
487template <typename TSrc, typename = ::std::enable_if<can_be_operand<TSrc>::value>>
488struct UnaryElementwiseFunction
489{
490 TSrc src;
491 UnaryFunction opcode;
492};
493
494template <typename TLeft>
495struct can_be_operand<UnaryElementwiseFunction<TLeft>> : ::std::true_type
496{
497};
498
499/** Represents the expression: `exp(\p src)`.
500 *
501 * @tparam TSrc The type of the argument.
502 * @param[in] src The argument.
503 * @return The resulting AST node.
504 */
505template <typename TSrc>
506UnaryElementwiseFunction<TSrc> exp(TSrc &&src)
507{
508 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Exp };
509}
510
511/** Represents the expression: `tanh(\p src)`.
512 *
513 * @tparam TSrc The type of the argument.
514 * @param[in] src The argument.
515 * @return The resulting AST node.
516 */
517template <typename TSrc>
518UnaryElementwiseFunction<TSrc> tanh(TSrc &&src)
519{
520 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Tanh };
521}
522
523/** Represents the expression: `sqrt(\p src)`.
524 *
525 * @tparam TSrc The type of the argument.
526 * @param[in] src The argument.
527 * @return The resulting AST node.
528 */
529template <typename TSrc>
530UnaryElementwiseFunction<TSrc> sqrt(TSrc &&src)
531{
532 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Sqrt };
533}
534
535/** Represents the expression: `erf(\p src)`.
536 *
537 * @tparam TSrc The type of the argument.
538 * @param[in] src The argument.
539 * @return The resulting AST node.
540 */
541template <typename TSrc>
542UnaryElementwiseFunction<TSrc> erf(TSrc &&src)
543{
544 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Erf };
545}
546
547/** Represents the expression: `fabs(\p src)`.
548 *
549 * @tparam TSrc The type of the argument.
550 * @param[in] src The argument.
551 * @return The resulting AST node.
552 */
553template <typename TSrc>
554UnaryElementwiseFunction<TSrc> fabs(TSrc &&src)
555{
556 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Fabs };
557}
558
559/** Represents the expression: `log(\p src)`.
560 *
561 * @tparam TSrc The type of the argument.
562 * @param[in] src The argument.
563 * @return The resulting AST node.
564 */
565template <typename TSrc>
566UnaryElementwiseFunction<TSrc> log(TSrc &&src)
567{
568 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Log };
569}
570
571/** Represents the expression: `round(\p src)`.
572 *
573 * @tparam TSrc The type of the argument.
574 * @param[in] src The argument.
575 * @return The resulting AST node.
576 */
577template <typename TSrc>
578UnaryElementwiseFunction<TSrc> round(TSrc &&src)
579{
580 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Round };
581}
582
583/** Represents the expression: `sizeof(\p src)`.
584 *
585 * @tparam TSrc The type of the argument.
586 * @param[in] src The argument.
587 * @return The resulting AST node.
588 */
589template <typename TSrc>
590UnaryElementwiseFunction<TSrc> sizeOf(TSrc &&src)
591{
592 return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::SizeOf };
593}
594
595// ==================================================
596// Binary elementwise functions
597// ==================================================
598
599/** AST node for binary elementwise functions.
600 *
601 * Note that both \p TFirst and \p TSecond must be operands.
602 *
603 * @tparam TFirst The type of the left argument of the function.
604 * @tparam TSecond The type of the right argument of the function.
605 */
606template <typename TFirst, typename TSecond, typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value>>
607struct BinaryElementwiseFunction
608{
609 TFirst first;
610 TSecond second;
611 BinaryFunction opcode;
612};
613
614template <typename TFirst, typename TSecond>
615struct can_be_operand<BinaryElementwiseFunction<TFirst, TSecond>> : ::std::true_type
616{
617};
618
619/** Represents the function call: `max(\p first, \p second)`.
620 *
621 * @tparam TFirst The type of the first argument.
622 * @tparam TSecond The type of the second argument.
623 * @param[in] first The first argument.
624 * @param[in] second The second argument.
625 * @return The resulting AST node.
626 */
627template <typename TFirst, typename TSecond>
628BinaryElementwiseFunction<TFirst, TSecond> max(TFirst &&first, TSecond &&second)
629{
630 return BinaryElementwiseFunction<TFirst, TSecond>{ std::forward<TFirst>(first), std::forward<TSecond>(second), BinaryFunction::Max };
631}
632
633/** Represents the function call: `min(\p first, \p second)`.
634 *
635 * @tparam TFirst The type of the first argument.
636 * @tparam TSecond The type of the second argument.
637 * @param[in] first The first argument.
638 * @param[in] second The second argument.
639 * @return The resulting AST node.
640 */
641template <typename TFirst, typename TSecond>
642BinaryElementwiseFunction<TFirst, TSecond> min(TFirst &&first, TSecond &&second)
643{
644 return BinaryElementwiseFunction<TFirst, TSecond>{ std::forward<TFirst>(first), std::forward<TSecond>(second), BinaryFunction::Min };
645}
646
647// ==================================================
648// Ternary elementwise functions
649// ==================================================
650
651/** AST node for ternary elementwise functions.
652 *
653 * Note that \p TFirst, \p TSecond, and \p TThird all must be operands.
654 *
655 * @tparam TFirst The type of the first argument to the function.
656 * @tparam TSecond The type of the second argument to the function.
657 * @tparam TThird The type of the third argument to the function.
658 */
659template <typename TFirst, typename TSecond, typename TThird, typename = ::std::enable_if<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value && can_be_operand<TThird>::value>>
660struct TernaryElementwiseFunction
661{
662 TFirst first;
663 TSecond second;
664 TThird third;
665 TernaryFunction opcode;
666};
667
668template <typename TFirst, typename TSecond, typename TThird>
669struct can_be_operand<TernaryElementwiseFunction<TFirst, TSecond, TThird>> : ::std::true_type
670{
671};
672
673/** Represents the function call: `select(\p first, \p second, \p third)`.
674 *
675 * @tparam TFirst The type of the first argument.
676 * @tparam TSecond The type of the second argument.
677 * @tparam TThird The type of the third argument.
678 * @param[in] first The first argument.
679 * @param[in] second The second argument.
680 * @param[in] third The third argument.
681 * @return The resulting AST node.
682 */
683template <typename TFirst, typename TSecond, typename TThird>
684TernaryElementwiseFunction<TFirst, TSecond, TThird> select(TFirst &&first, TSecond &&second, TThird &&third)
685{
686 return TernaryElementwiseFunction<TFirst, TSecond, TThird>{ std::forward<TFirst>(first), std::forward<TSecond>(second), std::forward<TThird>(third), TernaryFunction::Select };
687}
688
689/** Helper class used to extend a KernelWriter with additional functionality
690 * in order to make writing easier.
691 *
692 * This extension automatically handles creation of temporary variables, and
693 * allows nested function calls and operations.
694 *
695 * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter.
696 */
697template <class TWriter, typename = std::enable_if<std::is_base_of<KernelWriter, TWriter>::value>>
698class KernelWriterHelper : public TWriter
699{
700public:
701 using TWriter::TWriter;
702
703 // ==================================================
704 // If-statements
705 // ==================================================
706
707 // Un-hide original implementation, in case the original implementation is required.
708 using TWriter::op_if;
709
710 /** Represents the if-statement: `if(\p cond) { \p body }`.
711 *
712 * The BinaryExpression is unpacked and its components are forwarded to
713 * the underlying KernelWriter's implementation.
714 *
715 * @param[in] cond The BinaryExpression representing the condition.
716 * @param[in] body The body of the if-statement.
717 */
718 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TileOperand &> &cond, const std::function<void()> &body)
719 {
720 TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body);
721 return *this;
722 }
723
724 /** Represents the if-statement: `if(\p cond) { \p body }`.
725 *
726 * The BinaryExpression is unpacked and its components are forwarded to
727 * the underlying KernelWriter's implementation.
728 *
729 * @param[in] cond The BinaryExpression representing the condition.
730 * @param[in] body The body of the if-statement.
731 */
732 template <typename TRight>
733 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TRight> &cond, const std::function<void()> &body)
734 {
735 auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
736 op_assign(tmp1, cond.rhs);
737 TWriter::op_if(cond.lhs, cond.opcode, tmp1, body);
738 return *this;
739 }
740
741 /** Represents the if-statement: `if(\p cond) { \p body }`.
742 *
743 * The BinaryExpression is unpacked and its components are forwarded to
744 * the underlying KernelWriter's implementation.
745 *
746 * @param[in] cond The BinaryExpression representing the condition.
747 * @param[in] body The body of the if-statement.
748 */
749 template <typename TLeft>
750 KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TLeft, TileOperand &> &cond, const std::function<void()> &body)
751 {
752 auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
753 op_assign(tmp1, cond.lhs);
754 TWriter::op_if(tmp1, cond.opcode, cond.rhs, body);
755 return *this;
756 }
757
758 // Un-hide original implementation, in case the original implementation is required.
759 using TWriter::op_else_if;
760
761 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
762 *
763 * The BinaryExpression is unpacked and its components are forwarded to
764 * the underlying KernelWriter's implementation.
765 *
766 * @param[in] cond The BinaryExpression representing the condition.
767 * @param[in] body The body of the else-if-statement.
768 */
769 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TileOperand &> &cond, const std::function<void()> &body)
770 {
771 TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body);
772 return *this;
773 }
774
775 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
776 *
777 * The BinaryExpression is unpacked and its components are forwarded to
778 * the underlying KernelWriter's implementation.
779 *
780 * @param[in] cond The BinaryExpression representing the condition.
781 * @param[in] body The body of the else-if-statement.
782 */
783 template <typename TRight>
784 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TRight> &cond, const std::function<void()> &body)
785 {
786 auto &tmp1 = declare_temp_tile(cond.lhs.tile_info());
787 op_assign(tmp1, cond.rhs);
788 TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body);
789 return *this;
790 }
791
792 /** Represents the else-if-statement: `else if(\p cond) { \p body }`.
793 *
794 * The BinaryExpression is unpacked and its components are forwarded to
795 * the underlying KernelWriter's implementation.
796 *
797 * @param[in] cond The BinaryExpression representing the condition.
798 * @param[in] body The body of the else-if-statement.
799 */
800 template <typename TLeft>
801 KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TLeft, TileOperand &> &cond, const std::function<void()> &body)
802 {
803 auto &tmp1 = declare_temp_tile(cond.rhs.tile_info());
804 op_assign(tmp1, cond.lhs);
805 TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body);
806 return *this;
807 }
808
809 // ==================================================
810 // For-loops
811 // ==================================================
812
813 // Un-hide original implementation, in case the original implementation is required.
814 using TWriter::op_for_loop;
815
816 /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`.
817 *
818 * The BinaryExpression for the condition and the Assignment
819 * for the updater are unpacked and their components are forwarded to
820 * the underlying KernelWriter's implementation.
821 *
822 * @param[in] cond The BinaryExpression representing the condition.
823 * @param[in] updater The Assignment representing the updater.
824 * @param[in] body The body of the for-loop.
825 */
826 void op_for_loop(const BinaryExpression<TileOperand &, TileOperand &> &cond, const Assignment<TileOperand &, TileOperand &> &updater, const std::function<void()> &body)
827 {
828 TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body);
829 }
830
831 // ==================================================
832 // Unary expressions
833 // ==================================================
834
835 // Un-hide original implementation, in case the original implementation is required.
836 using TWriter::op_assign;
837
838 /** Represents the assignment: `\p dst = \p exp`.
839 *
840 * The UnaryExpression is unpacked and its components are forwarded to
841 * the underlying KernelWriter's implementation.
842 *
843 * @param[in] dst The tile which is assigned to.
844 * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned.
845 */
846 void op_assign(const TileOperand &dst, const UnaryExpression<TileOperand &> &exp)
847 {
848 TWriter::op_unary_expression(dst, exp.opcode, exp.src);
849 }
850
851 // ==================================================
852 // Binary expressions
853 // ==================================================
854
855 /** Represents the assignment: `\p dst = \p exp`.
856 *
857 * The BinaryExpression is unpacked and its components are forwarded to
858 * the underlying KernelWriter's implementation.
859 *
860 * @param[in] dst The tile which is assigned to.
861 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
862 */
863 void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TileOperand &> &exp)
864 {
865 TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs);
866 }
867
868 /** Represents the assignment: `\p dst = \p exp`.
869 *
870 * The BinaryExpression is unpacked and its components are forwarded to
871 * the underlying KernelWriter's implementation.
872 *
873 * @param[in] dst The tile which is assigned to.
874 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
875 */
876 template <typename TRight>
877 void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TRight> &exp)
878 {
879 std::cout << "Beginning assignment!" << std::endl;
880 auto &tmp1 = declare_temp_tile(dst.tile_info());
881 op_assign(tmp1, exp.rhs);
882 TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1);
883 }
884
885 /** Represents the assignment: `\p dst = \p exp`.
886 *
887 * The BinaryExpression is unpacked and its components are forwarded to
888 * the underlying KernelWriter's implementation.
889 *
890 * @param[in] dst The tile which is assigned to.
891 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
892 */
893 template <typename TLeft>
894 void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TileOperand &> &exp)
895 {
896 std::cout << "Beginning assignment!" << std::endl;
897 auto &tmp1 = declare_temp_tile(dst.tile_info());
898 op_assign(tmp1, exp.lhs);
899 TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs);
900 }
901
902 /** Represents the assignment: `\p dst = \p exp`.
903 *
904 * The BinaryExpression is unpacked and its components are forwarded to
905 * the underlying KernelWriter's implementation.
906 *
907 * @param[in] dst The tile which is assigned to.
908 * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
909 */
910 template <typename TLeft, typename TRight>
911 void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TRight> &exp)
912 {
913 auto &tmp1 = declare_temp_tile(dst.tile_info());
914 auto &tmp2 = declare_temp_tile(dst.tile_info());
915 op_assign(tmp1, exp.lhs);
916 op_assign(tmp2, exp.rhs);
917 TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2);
918 }
919
920 // ==================================================
921 // Unary elementwise functions
922 // ==================================================
923
924 /** Represents the assignment: `\p dst = \p exp`.
925 *
926 * The UnaryElementwiseFunction is unpacked and its components are forwarded to
927 * the underlying KernelWriter's implementation.
928 *
929 * @param[in] dst The tile which is assigned to.
930 * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
931 */
932 void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TileOperand &> &exp)
933 {
934 TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src);
935 }
936
937 /** Represents the assignment: `\p dst = \p exp`.
938 *
939 * The UnaryElementwiseFunction is unpacked and its components are forwarded to
940 * the underlying KernelWriter's implementation.
941 *
942 * @param[in] dst The tile which is assigned to.
943 * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
944 */
945 template <typename TArg>
946 void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TArg> &exp)
947 {
948 auto &tmp1 = declare_temp_tile(dst.tile_info());
949 op_assign(tmp1, exp.lhs);
950 TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1);
951 }
952
953 // ==================================================
954 // Binary elementwise functions
955 // ==================================================
956
957 /** Represents the assignment: `\p dst = \p exp`.
958 *
959 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
960 * the underlying KernelWriter's implementation.
961 *
962 * @param[in] dst The tile which is assigned to.
963 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
964 */
965 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TileOperand &> &exp)
966 {
967 TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second);
968 }
969
970 /** Represents the assignment: `\p dst = \p exp`.
971 *
972 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
973 * the underlying KernelWriter's implementation.
974 *
975 * @param[in] dst The tile which is assigned to.
976 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
977 */
978 template <typename TRight>
979 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TRight> &exp)
980 {
981 auto &tmp1 = declare_temp_tile(dst.tile_info());
982 op_assign(tmp1, exp.second);
983 TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1);
984 }
985
986 /** Represents the assignment: `\p dst = \p exp`.
987 *
988 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
989 * the underlying KernelWriter's implementation.
990 *
991 * @param[in] dst The tile which is assigned to.
992 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
993 */
994 template <typename TLeft>
995 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TileOperand &> &exp)
996 {
997 auto &tmp1 = declare_temp_tile(dst.tile_info());
998 op_assign(tmp1, exp.first);
999 TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second);
1000 }
1001
1002 /** Represents the assignment: `\p dst = \p exp`.
1003 *
1004 * The BinaryElementwiseFunction is unpacked and its components are forwarded to
1005 * the underlying KernelWriter's implementation.
1006 *
1007 * @param[in] dst The tile which is assigned to.
1008 * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned.
1009 */
1010 template <typename TLeft, typename TRight>
1011 void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TRight> &exp)
1012 {
1013 auto &tmp1 = declare_temp_tile(dst.tile_info());
1014 auto &tmp2 = declare_temp_tile(dst.tile_info());
1015 op_assign(tmp1, exp.first);
1016 op_assign(tmp2, exp.second);
1017 TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2);
1018 }
1019
1020 // ==================================================
1021 // Ternary elementwise functions
1022 // ==================================================
1023
1024 /** Represents the assignment: `\p dst = \p exp`.
1025 *
1026 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1027 * the underlying KernelWriter's implementation.
1028 *
1029 * @param[in] dst The tile which is assigned to.
1030 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1031 */
1032 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TileOperand &> &exp)
1033 {
1034 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third);
1035 }
1036
1037 /** Represents the assignment: `\p dst = \p exp`.
1038 *
1039 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1040 * the underlying KernelWriter's implementation.
1041 *
1042 * @param[in] dst The tile which is assigned to.
1043 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1044 */
1045 template <typename TFirst>
1046 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TileOperand &> &exp)
1047 {
1048 auto &tmp1 = declare_temp_tile(dst.tile_info());
1049 op_assign(tmp1, exp.first);
1050 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third);
1051 }
1052
1053 /** Represents the assignment: `\p dst = \p exp`.
1054 *
1055 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1056 * the underlying KernelWriter's implementation.
1057 *
1058 * @param[in] dst The tile which is assigned to.
1059 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1060 */
1061 template <typename TSecond>
1062 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TileOperand &> &exp)
1063 {
1064 auto &tmp1 = declare_temp_tile(dst.tile_info());
1065 op_assign(tmp1, exp.second);
1066 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third);
1067 }
1068
1069 /** Represents the assignment: `\p dst = \p exp`.
1070 *
1071 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1072 * the underlying KernelWriter's implementation.
1073 *
1074 * @param[in] dst The tile which is assigned to.
1075 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1076 */
1077 template <typename TThird>
1078 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TThird> &exp)
1079 {
1080 auto &tmp1 = declare_temp_tile(dst.tile_info());
1081 op_assign(tmp1, exp.third);
1082 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1);
1083 }
1084
1085 /** Represents the assignment: `\p dst = \p exp`.
1086 *
1087 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1088 * the underlying KernelWriter's implementation.
1089 *
1090 * @param[in] dst The tile which is assigned to.
1091 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1092 */
1093 template <typename TFirst, typename TSecond>
1094 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TileOperand &> &exp)
1095 {
1096 auto &tmp1 = declare_temp_tile(dst.tile_info());
1097 auto &tmp2 = declare_temp_tile(dst.tile_info());
1098 op_assign(tmp1, exp.first);
1099 op_assign(tmp2, exp.second);
1100 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third);
1101 }
1102
1103 /** Represents the assignment: `\p dst = \p exp`.
1104 *
1105 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1106 * the underlying KernelWriter's implementation.
1107 *
1108 * @param[in] dst The tile which is assigned to.
1109 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1110 */
1111 template <typename TFirst, typename TThird>
1112 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TThird> &exp)
1113 {
1114 auto &tmp1 = declare_temp_tile(dst.tile_info());
1115 auto &tmp2 = declare_temp_tile(dst.tile_info());
1116 op_assign(tmp1, exp.first);
1117 op_assign(tmp2, exp.third);
1118 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2);
1119 }
1120
1121 /** Represents the assignment: `\p dst = \p exp`.
1122 *
1123 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1124 * the underlying KernelWriter's implementation.
1125 *
1126 * @param[in] dst The tile which is assigned to.
1127 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1128 */
1129 template <typename TSecond, typename TThird>
1130 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TThird> &exp)
1131 {
1132 auto &tmp1 = declare_temp_tile(dst.tile_info());
1133 auto &tmp2 = declare_temp_tile(dst.tile_info());
1134 op_assign(tmp1, exp.second);
1135 op_assign(tmp2, exp.third);
1136 TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2);
1137 }
1138
1139 /** Represents the assignment: `\p dst = \p exp`.
1140 *
1141 * The TernaryElementwiseFunction is unpacked and its components are forwarded to
1142 * the underlying KernelWriter's implementation.
1143 *
1144 * @param[in] dst The tile which is assigned to.
1145 * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
1146 */
1147 template <typename TFirst, typename TSecond, typename TThird>
1148 void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TThird> &exp)
1149 {
1150 auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info());
1151 auto &tmp2 = declare_temp_tile(dst.tile_info());
1152 auto &tmp3 = declare_temp_tile(dst.tile_info());
1153 op_assign(tmp1, exp.first);
1154 op_assign(tmp2, exp.second);
1155 op_assign(tmp3, exp.third);
1156 TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3);
1157 }
1158
1159 // ==================================================
1160 // Assignments
1161 // ==================================================
1162
1163 /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
1164 *
1165 * The Assignment is unpacked and its components are forwarded to
1166 * the underlying KernelWriter's implementation.
1167 *
1168 * @param[in] exp The Assignment representing the expression to be evaluated.
1169 */
1170 void op_assign(const Assignment<TileOperand &, TileOperand &> &exp)
1171 {
1172 if(exp.opcode == AssignmentOp::Increment)
1173 {
1174 TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs);
1175 }
1176 else if(exp.opcode == AssignmentOp::Decrement)
1177 {
1178 TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs);
1179 }
1180 }
1181
1182 /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`.
1183 *
1184 * The Assignment is unpacked and its components are forwarded to
1185 * the underlying KernelWriter's implementation.
1186 *
1187 * @tparam TRight The type of the RHS of the assignment.
1188 * @param[in] exp The Assignment representing the expression to be evaluated.
1189 */
1190 template <typename TRight>
1191 void op_assign(const Assignment<TileOperand &, TRight> &exp)
1192 {
1193 auto &tmp1 = declare_temp_tile(exp.lhs.tile_info());
1194 op_assign(tmp1, exp.rhs);
1195 op_assign(Assignment<TileOperand &, TileOperand &>{ exp.lhs, tmp1, exp.opcode });
1196 }
1197
1198private:
// Number of temporary names handed out so far; also the index of the next one.
unsigned int temp_var_counter = 0;

/** Post-increments the temporary variable counter.
 *
 * @return The value the counter held before it was incremented.
 */
int next_ctr()
{
    const unsigned int current = temp_var_counter;
    ++temp_var_counter;
    return static_cast<int>(current);
}
1209
1210 /** Gets the next temporary variable counter value,
1211 * and returns a suitable temporary variable name.
1212 *
1213 * @return A temporary variable name.
1214 */
1215 std::string next_tmp_var_name()
1216 {
1217 return "tmp_" + std::to_string(next_ctr());
1218 }
1219
1220 /** Returns the argument.
1221 *
1222 * Used for recursion with the variadic function version of this function.
1223 *
1224 * @param[in] arg The TileInfo to return.
1225 * @return The \p arg.
1226 */
1227 TileInfo get_largest_size(const TileInfo &arg)
1228 {
1229 return arg;
1230 }
1231
1232 /** Returns a TileInfo object where the size in each dimension (width, height) is the largest
1233 * of either TileInfo argument in the corresponding dimension.
1234 *
1235 * @tparam TOps Must be of TileInfo type.
1236 * @param[in] first A TileInfo object.
1237 * @param[in] second A TileInfo object.
1238 * @param[in] ops A number of TileInfo objects.
1239 * @return A TileInfo object which represents the largest shape in each dimension across the arguments.
1240 */
1241 template <typename... TOps, typename = ::std::enable_if_t<std::is_same<TOps..., TileInfo>::value>>
1242 TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops)
1243 {
1244 TileInfo largest = {
1245 first.data_type(),
1246 std::max(first.width(), second.width()),
1247 std::max(first.height(), second.height())
1248 };
1249 return get_largest_size(largest, ops...);
1250 }
1251
1252 /** Helper function to define a suitable TileOperand with appropriate TileInfo
1253 * such that broadcasting is taken into account, based on the arguments provided.
1254 *
1255 * @tparam TArgs Must be of TileInfo type.
1256 * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare.
1257 * @return A newly created TileOperand.
1258 */
1259 template <typename... TArgs, typename = ::std::enable_if_t<std::is_same<TArgs..., TileInfo>::value>>
1260 TileOperand &declare_temp_tile(const TArgs &...args)
1261 {
1262 return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...));
1263 }
1264};
1265
1266} // namespace ckw
1267
1268#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H