blob: 022ae83999d94ff1bbc12782e5c8efb8018263b2 [file] [log] [blame]
Viet-Hoa Do3389f532023-07-05 17:36:40 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_INCLUDE_CKW_KERNELWRITER_H
26#define CKW_INCLUDE_CKW_KERNELWRITER_H
27
#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"
#include "ckw/types/ConstantData.h"
#include "ckw/types/ConvertPolicy.h"
#include "ckw/types/Operators.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
37
38namespace ckw
39{
40
/** Forward Declarations */
class Kernel;
class TensorInfo;
class TensorSampler;
class TileInfo;

enum class DataType;
enum class TargetArchitecture;
enum class TargetLanguage;
50
Viet-Hoa Do3389f532023-07-05 17:36:40 +010051/** A kernel writer.
52 *
53 * This class is used to construct a new kernel by defining arguments, declaring variable and writing code.
54 *
55 * Use @ref KernelWriter::create_instance method to create the kernel writer for the specific target architecture and language.
56 *
57 * After having finished constructing the kernel, call @ref KernelWriter::emit_kernel to get the kernel object.
58 */
59class KernelWriter
60{
61public:
62 // =============================================================================================
63 // Construtors and destructor
64 // =============================================================================================
65
66 /** Initialize a new instance of @ref KernelWriter class for the specific architecture and language.
67 *
68 * Supported target architectures and languages:
69 *
70 * Architecture | Languages |
71 * ------------------------------|------------------------------|
72 * GpuArmMaliValhall | OpenCL |
73 *
74 * @param[in] architecture The architecture on which the kernel is executed.
75 * @param[in] language The language to write the kernel.
76 */
77 static std::unique_ptr<KernelWriter> create_instance(TargetArchitecture architecture, TargetLanguage language);
78
79 /** Destructor */
80 virtual ~KernelWriter();
81
82 // =============================================================================================
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +010083 // Data processing
84 // =============================================================================================
85
86 /** Write assignment statement: `<dst> = <src>;`.
87 *
88 * @param[in] dst The destination tile.
89 * @param[in] src The source tile.
90 */
91 virtual void op_assign(const TileOperand &dst, const TileOperand &src) = 0;
92
93 /** Write the cast statement: `<dst> = convert_<dst.type><policy>(<src>);`.
94 *
95 * @param[in] dst The destination tile.
96 * @param[in] src The source tile.
97 * @param[in] policy The policy governing the behavior of the cast.
98 */
99 virtual void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) = 0;
100
101 /** Write the unary expression statement: `<dst> = <op> <src>;`.
102 *
103 * @param[in] dst The destination tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100104 * @param[in] op The unary operator.
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100105 * @param[in] src The source tile.
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100106 */
Viet-Hoa Do34b6c3a2023-08-22 11:11:23 +0100107 virtual void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) = 0;
108
109 /** Write the binary expression statement: `<dst> = <op>(<first>, <second>);`.
110 *
111 * @param[in] dst The destination tile.
112 * @param[in] op The binary operator.
113 * @param[in] first The first source tile.
114 * @param[in] second The second source tile.
115 */
116 virtual void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0;
117
118 /** Write ternary expression statement: `<dst> = <op>(<first>, <second>, <third>);`.
119 *
120 * @param[in] dst The destination tile.
121 * @param[in] op The ternary operator.
122 * @param[in] first The first source tile.
123 * @param[in] second The second source tile.
124 * @param[in] third The third source tile.
125 */
126 virtual void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) = 0;
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100127
128 // =============================================================================================
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100129 // Misc
130 // =============================================================================================
131
132 /** Write the line comment in debug build.
Gunes Bayirab0b7502023-07-11 14:57:36 +0100133 *
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100134 * This function does not take effect on release build.
135 *
136 * The comment must only contain one line (i.e. no newline character is allowed).
137 *
138 * @param[in] text The comment to be written.
139 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100140 virtual void op_comment(const std::string &text) = 0;
141
142 /** Write the given raw code to kernel source code
143 * It's used to address the cases where the user needs to
144 * explicitly add a code where it's not (yet) supported by
145 * the kernel writer utility calls.
146 *
147 * @param[in] raw_code raw code to write as string
148 */
149 virtual void op_write_raw_code(const std::string &raw_code) = 0;
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100150
151 // =============================================================================================
152 // Code generation
153 // =============================================================================================
154
155 /** Emit the kernel object.
156 *
157 * @param[in] name The name of the kernel object to be generated.
158 */
159 virtual std::unique_ptr<Kernel> emit_kernel(const std::string &name) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100160
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100161 // =============================================================================================
162 // Tensor and tile declaration
163 // =============================================================================================
164
165 /** Declare a tensor argument.
166 *
167 * @param[in] name The name of the tensor.
168 * @param[in] info The tensor info.
169 *
170 * @return The @ref TensorOperand object.
171 */
172 virtual TensorOperand declare_tensor_argument(const std::string &name, const TensorInfo &info) = 0;
173
Gunes Bayirab0b7502023-07-11 14:57:36 +0100174 /** Declare a tile given its name and tile info
175 *
176 * @param[in] name Name of the tile
177 * @param[in] tile_info Shape and data type of the tile
178 *
Gunes Bayir806b8e82023-08-23 23:28:31 +0100179 * @return The created tile operand
Gunes Bayirab0b7502023-07-11 14:57:36 +0100180 */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100181 virtual TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) = 0;
Gunes Bayirab0b7502023-07-11 14:57:36 +0100182
Gunes Bayir806b8e82023-08-23 23:28:31 +0100183 /** Declare a constant tile given a @ref:ConstantData object
184 *
185 * @param[in] data a @ref ckw::ConstantData object that has the values and the
186 * underlying data type of the constant tile
187 *
188 * @return The created constant tile operand
189 */
190 virtual TileOperand declare_constant_tile(const ConstantData &data) = 0;
191
Gunes Bayir47a396e2023-08-17 11:04:02 +0100192 /** Load the data from the tensor memory to the tile using the sampling information.
193 *
194 * @param[in] tile_op The tile to be loaded.
195 * @param[in] tensor_op The tensor to be read.
196 * @param[in] sampler The tensor sampling information.
197 * @param[in] x x-coordinate
198 * @param[in] y y-coordinate
199 * @param[in] z z-coordinate
200 * @param[in] batch batch offset
201 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100202 virtual void op_load(
203 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100204 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
205
206 /** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
207 *
208 * Similar to @ref KernelWriter::op_load() and
209 *
210 * @param[in] dilation_x Dilation while reading in x-dimension
211 * @param[in] dilation_y Dilation while reading in y-dimension
212 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100213 virtual void op_load_dilated(
214 const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100215 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
216 const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
217
218 /** Store the data to the tensor memory from the tile using the sampling information.
219 *
220 * Similar to @ref KernelWriter::op_load()
221 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100222 virtual void op_store(
223 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100224 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
225
226 /** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
227 *
228 * Similar to @ref KernelWriter::op_load_dilated()
229 */
Viet-Hoa Doe1c3b462023-07-31 17:13:34 +0100230 virtual void op_store_dilated(
231 const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
Gunes Bayir47a396e2023-08-17 11:04:02 +0100232 const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
233 const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
234
Gunes Bayirab0b7502023-07-11 14:57:36 +0100235protected:
236 int32_t id_space() const;
237
Gunes Bayirab0b7502023-07-11 14:57:36 +0100238 /** Generate full variable name by prefixing it with id space */
239 std::string generate_full_name(const std::string &name) const;
240
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100241 /** Create a new tile operand referring to the specified tile object. */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100242 static TileOperand create_tile_operand(ITile &tile);
243
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100244 /** Get the reference to tile object from the tile operand. */
Viet-Hoa Do25d26f42023-07-20 17:31:47 +0100245 static ITile &get_tile(const TileOperand &operand);
246
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100247 /** Create a new tensor operand from a tensor object. */
248 static TensorOperand create_tensor_operand(ITensor &tensor);
249
250 /** Get the reference to tensor object from the tensor operand. */
251 static ITensor &get_tensor(const TensorOperand &operand);
252
Gunes Bayir806b8e82023-08-23 23:28:31 +0100253 /** Get the values of a constant data object. */
254 static const std::vector<std::vector<std::string>> &get_values(const ConstantData &data);
255
256 /** Get the data type of a constant data object. */
257 static DataType get_data_type(const ConstantData &data);
258
Gunes Bayirab0b7502023-07-11 14:57:36 +0100259private:
260 int32_t _id_space{ 0 };
Viet-Hoa Do3389f532023-07-05 17:36:40 +0100261};
262
263} // namespace ckw
264
265#endif // CKW_INCLUDE_CKW_KERNELWRITER_H