/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#ifndef CKW_SRC_CL_CLKERNELWRITER_H
26#define CKW_SRC_CL_CLKERNELWRITER_H
27
#include "ckw/KernelWriter.h"

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
34
35namespace ckw
36{
37
Gunes Bayir806b8e82023-08-23 23:28:31 +010038// Forward Declarations
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010039class CLTile;
40class CLTensorArgument;
Gunes Bayir806b8e82023-08-23 23:28:31 +010041class ConstantData;
42class TensorOperand;
Gunes Bayir47a396e2023-08-17 11:04:02 +010043class TensorSampler;
44class TileOperand;
Gunes Bayir47a396e2023-08-17 11:04:02 +010045
Gunes Bayir806b8e82023-08-23 23:28:31 +010046enum class DataType;
Gunes Bayir47a396e2023-08-17 11:04:02 +010047enum class MemoryOperation;
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010048
Viet-Hoa Do3389f532023-07-05 17:36:40 +010049/** OpenCL kernel writer. */
50class CLKernelWriter : public KernelWriter
51{
52public:
53 // =============================================================================================
54 // Construtors and destructor
55 // =============================================================================================
56
57 /** Initialize a new instance of @ref CLKernelWriter class. */
58 CLKernelWriter();
59
60 /** Destructor */
61 ~CLKernelWriter();
62
63 // =============================================================================================
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010064 // Data processing
65 // =============================================================================================
66
67 void op_assign(const TileOperand &dst, const TileOperand &src) override;
68
69 void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) override;
70
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +010071 void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) override;
72
73 void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) override;
74
75 void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) override;
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010076
77 // =============================================================================================
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +010078 // Flow control
79 // =============================================================================================
80
81 void op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override;
82
83 void op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body) override;
84
85 void op_else(const std::function<void()> &body) override;
86
87 void op_for_loop(
88 const TileOperand &var, BinaryOp cond_op, const TileOperand &cond_value,
89 const TileOperand &update_var, AssignmentOp update_op, const TileOperand &update_value,
90 const std::function<void()> &body) override;
91
92 void op_return() override;
93
94 // =============================================================================================
Viet-Hoa Do3389f532023-07-05 17:36:40 +010095 // Misc
96 // =============================================================================================
97
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010098 void op_comment(const std::string &text) override;
Viet-Hoa Do3389f532023-07-05 17:36:40 +010099
Gunes Bayir366514d2023-07-27 22:52:32 +0100100 void op_write_raw_code(const std::string &raw_code) override;
101
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100102 // =============================================================================================
103 // Code generation
104 // =============================================================================================
105
106 std::unique_ptr<Kernel> emit_kernel(const std::string &name) override;
107
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100108 // =============================================================================================
109 // Tensor and tile declaration
110 // =============================================================================================
111
112 TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) override;
113
Gunes Bayirab0b7502023-07-11 14:57:36 +0100114 /** Declare a tile given name and tile information
115 *
116 * Similar to @ref KernelWriter::declare_tile()
Gunes Bayir47a396e2023-08-17 11:04:02 +0100117 */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100118 TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) override;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100119
Gunes Bayir806b8e82023-08-23 23:28:31 +0100120 /** Declare a constant tile given a @ref:ConstantData object
121 *
122 * Similar to @ref KernelWriter::declare_constant_tile()
123 */
124 TileOperand declare_constant_tile(const ConstantData &data) override;
125
Gunes Bayir47a396e2023-08-17 11:04:02 +0100126 // =============================================================================================
127 // Memory Operations
128 // =============================================================================================
129
130 /** Load the data from the tensor memory to the tile using the sampling information.
131 *
132 * Similar to @ref KernelWriter::op_load()
133 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100134 void op_load(
135 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100136 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
137
138 /** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
139 *
140 * Similar to @ref KernelWriter::op_load_dilated()
141 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100142 void op_load_dilated(
143 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100144 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
145 const TileOperand &dilation_x, const TileOperand &dilation_y) override;
146
147 /** Store the data to the tensor memory from the tile using the sampling information.
148 *
149 * Similar to @ref KernelWriter::op_store()
150 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100151 void op_store(
152 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100153 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
154
155 /** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
156 *
157 * Similar to @ref KernelWriter::op_store_dilated()
158 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100159 void op_store_dilated(
160 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100161 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
162 const TileOperand &dilation_x, const TileOperand &dilation_y) override;
163
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100164protected:
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100165 /** Return @ref CLTile object from the @ref TileOperand object.
166 *
167 * This function performs appropriate check before doing type casting.
168 */
Gunes Bayir806b8e82023-08-23 23:28:31 +0100169 const CLTile &to_cl_tile(const TileOperand &operand) const;
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100170
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100171 /** Append the specified code to the kernel body source code. */
172 template <typename T, typename... TArgs>
173 void append_code(T &&code, TArgs &&...args)
174 {
175 append_code(std::forward<T>(code));
176 append_code(std::forward<TArgs>(args)...);
177 }
178
179 /** Append the specified code to the kernel body source code. */
180 template <typename T>
181 void append_code(T &&code)
182 {
183 _body_source_code += std::forward<T>(code);
184 }
185
186 /** Get the current kernel body source code. */
187 const std::string &body_source_code() const;
188
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100189 // For helper functions
Gunes Bayir47a396e2023-08-17 11:04:02 +0100190private:
Gunes Bayir47a396e2023-08-17 11:04:02 +0100191 /** Helper function to consolidate all load/store logic in this class */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100192 void op_load_store(
193 MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100194 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
195 const CLTile &dilation_x, const CLTile &dilation_y);
196
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100197 /** This function is the generic function to write both `if` and `else if` blocks.
198 *
199 * It is used for both @ref CLKernelWriter::op_if and @ref CLKernelWriter::op_else_if.
200 *
Viet-Hoa Do98901e42023-08-30 10:12:22 +0100201 * @param[in] lhs The LHS tile of the condition.
202 * @param[in] op The relational binary operator.
203 * @param[in] rhs The RHS tile of the condition.
204 * @param[in] body The function that writes the body of the else-if block.
205 * @param[in] is_else_if True if this is an `else if` block, otherwise this is an `if` block.
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100206 */
Viet-Hoa Do98901e42023-08-30 10:12:22 +0100207 void op_if_generic(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body, bool is_else_if);
Viet-Hoa Do2d0c2f52023-08-24 11:48:19 +0100208
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100209 // For attributes
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100210private:
211 /** This string contains the kernel body source code, not the full CL source code.
212 * The full source code will only be generated when the user calls @ref KernelWriter::emit_kernel.
213 *
214 * In order to add code to this, use @ref CLKernelWriter::append_code.
215 * Do not attempt to concatenate and alter this string directly.
216 */
217 std::string _body_source_code{};
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100218
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100219 std::set<std::unique_ptr<CLTensorArgument>> _tensors{};
220 std::set<std::unique_ptr<CLTile>> _tiles{};
Gunes Bayir806b8e82023-08-23 23:28:31 +0100221 std::set<std::unique_ptr<CLTile>> _constant_tiles{};
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100222};
223
224} // namespace ckw
225
226#endif // CKW_SRC_CL_CLKERNELWRITER_H