blob: 73458efa1dd48cb75c392ca08dec6eca44a38278 [file] [log] [blame]
/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "ckw/KernelWriter.h"
26#include "ckw/Error.h"
27#include "ckw/TensorOperand.h"
28#include "src/Prototype.h"
29
30#include <sstream>
31
32namespace ckw
33{
34
35namespace
36{
37
38inline prototype::TensorInfo create_impl_tensor_info(const TensorInfo &info)
39{
40 return prototype::TensorInfo{ info.shape(), info.data_type(), info.data_layout(), info.id() };
41}
42
43} // namespace
44
// =================================================================================================
// Constructors and destructor
// =================================================================================================
48
49KernelWriter::KernelWriter(Kernel &kernel)
50 : _kernel(&kernel),
51 _impl_attr(std::make_unique<prototype::GpuKernelWriterAttribute>()),
52 _impl(prototype::GpuKernelWriterFactory::create(_impl_attr.get(), kernel.impl()))
53{
54 _impl->set_IdSpace(1);
55}
56
57KernelWriter::~KernelWriter()
58{
59}
60
// =================================================================================================
// Scope management
// =================================================================================================
64
65int32_t KernelWriter::id_space() const
66{
67 return _id_space;
68}
69
70KernelWriter &KernelWriter::id_space(int32_t id_space)
71{
72 CKW_ASSERT(id_space <= _max_id_space);
73
74 _id_space = id_space;
75 return *this;
76}
77
78int32_t KernelWriter::next_id_space()
79{
80 id_space(++_max_id_space);
81 return _id_space;
82}
83
// =================================================================================================
// Tensor and tile declaration
// =================================================================================================
87
Nikolaj Jensen5ff48022023-06-27 14:13:24 +010088TensorOperand &KernelWriter::declare_tensor_argument(const std::string &name, const TensorInfo &info)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010089{
90 const auto var_name = generate_variable_name(name);
91
92 _impl->declare_argument(var_name, create_impl_tensor_info(info));
93
94 auto operand = new TensorOperand(var_name, info);
95 register_operand(operand, false);
96
97 return *operand;
98}
99
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100100TileOperand &KernelWriter::declare_tile_argument(const std::string &name, int32_t value)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100101{
102 const auto var_name = generate_variable_name(name);
103
104 auto operand = new TileOperand(var_name, value);
105 register_operand(operand, false);
106
107 return *operand;
108}
109
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100110std::string KernelWriter::generate_variable_name(const std::string &name) const
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100111{
112 std::stringstream var_name;
113
114 var_name << "_" << _id_space << "_" << name;
115
116 return var_name.str();
117}
118
119void KernelWriter::register_operand(OperandBase *operand, bool declaring)
120{
121 const auto &name = operand->name();
122 auto &operands = _kernel->operands();
123
124 CKW_ASSERT(operands.find(name) == operands.end());
125 operands[name] = std::unique_ptr<OperandBase>(operand);
126
127 if(declaring && !operand->is_constant())
128 {
129 const auto tile = reinterpret_cast<TileOperand *>(operand);
130
131 const auto &info = tile->tile_info();
132 _impl->declare_tile(tile->name(), prototype::TileInfo(info.data_type(), info.width(), info.height()));
133 }
134}
135
// =================================================================================================
// Load and store
// =================================================================================================
139
140void KernelWriter::op_load(TileOperand &tile, TensorOperand &tensor, const TensorTileSampler &sampler)
141{
Viet-Hoa Doe1880f02023-06-28 10:25:35 +0100142 prototype::TensorOperand impl_tensor(
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100143 tensor.name(),
144 prototype::GpuSampler{
145 sampler.format(),
146 prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
147 sampler.address_mode_x(),
148 sampler.address_mode_y(),
149 sampler.address_mode_z() });
150
151 auto impl_x = sampler.x().create_impl_operand(_impl.get());
152 auto impl_y = sampler.y().create_impl_operand(_impl.get());
153 auto impl_z = sampler.z().create_impl_operand(_impl.get());
154 auto impl_b = sampler.b().create_impl_operand(_impl.get());
155
156 auto impl_dst = tile.create_impl_operand(_impl.get());
157
158 _impl->op_load_immediate(impl_tensor, impl_dst, impl_x, impl_y, impl_z, impl_b);
159}
160
161void KernelWriter::op_store(TensorOperand &tensor, const TileOperand &tile, const TensorTileSampler &sampler)
162{
Viet-Hoa Doe1880f02023-06-28 10:25:35 +0100163 prototype::TensorOperand impl_tensor(
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100164 tensor.name(),
165 prototype::GpuSampler{
166 sampler.format(),
167 prototype::GpuSamplerTensorStorage::BufferUint8Ptr,
168 sampler.address_mode_x(),
169 sampler.address_mode_y(),
170 sampler.address_mode_z() });
171 auto impl_src = tile.create_impl_operand(_impl.get());
172 auto impl_x = sampler.x().create_impl_operand(_impl.get());
173 auto impl_y = sampler.y().create_impl_operand(_impl.get());
174 auto impl_z = sampler.z().create_impl_operand(_impl.get());
175 auto impl_b = sampler.b().create_impl_operand(_impl.get());
176
177 _impl->op_store_immediate(impl_tensor, impl_src, impl_x, impl_y, impl_z, impl_b);
178}
179
// =================================================================================================
// Data processing
// =================================================================================================
183
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100184void KernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100185{
186 auto impl_dst = dst.create_impl_operand(_impl.get());
187 auto impl_src = src.create_impl_operand(_impl.get());
188
189 _impl->op_assign(impl_dst, impl_src);
190}
191
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100192void KernelWriter::op_cast_expression(const TileOperand &dst, const TileOperand &src, const ConvertPolicy policy)
193{
194 auto impl_dst = dst.create_impl_operand(_impl.get());
195 auto impl_src = src.create_impl_operand(_impl.get());
196
197 _impl->op_cast_expression(impl_dst, impl_src, policy);
198}
199
200void KernelWriter::op_binary_expression(const TileOperand &dst, const TileOperand &lhs, BinaryOp op, const TileOperand &rhs)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100201{
202 auto impl_lhs = lhs.create_impl_operand(_impl.get());
203 auto impl_rhs = rhs.create_impl_operand(_impl.get());
204 auto impl_dst = dst.create_impl_operand(_impl.get());
205
206 _impl->op_binary_expression(impl_dst, impl_lhs, op, impl_rhs);
207}
208
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100209void KernelWriter::op_unary_expression(const TileOperand &dst, UnaryOp op, const TileOperand &src)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100210{
211 auto impl_dst = dst.create_impl_operand(_impl.get());
212 auto impl_src = src.create_impl_operand(_impl.get());
213
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100214 _impl->op_unary_expression(impl_dst, op, impl_src);
215}
216
217void KernelWriter::op_unary_elementwise_function(const TileOperand &dst, UnaryFunction opcode, const TileOperand &src)
218{
219 auto impl_dst = dst.create_impl_operand(_impl.get());
220 auto impl_src = src.create_impl_operand(_impl.get());
221
222 _impl->op_unary_elementwise_function(impl_dst, opcode, impl_src);
223}
224
225void KernelWriter::op_binary_elementwise_function(const TileOperand &dst, BinaryFunction opcode, const TileOperand &first, const TileOperand &second)
226{
227 auto impl_dst = dst.create_impl_operand(_impl.get());
228 auto impl_first = first.create_impl_operand(_impl.get());
229 auto impl_second = second.create_impl_operand(_impl.get());
230
231 _impl->op_binary_elementwise_function(impl_dst, opcode, impl_first, impl_second);
232}
233
234void KernelWriter::op_ternary_elementwise_function(const TileOperand &dst, TernaryFunction opcode, const TileOperand &first, const TileOperand &second, const TileOperand &third)
235{
236 auto impl_dst = dst.create_impl_operand(_impl.get());
237 auto impl_first = first.create_impl_operand(_impl.get());
238 auto impl_second = second.create_impl_operand(_impl.get());
239 auto impl_third = third.create_impl_operand(_impl.get());
240
241 _impl->op_ternary_elementwise_function(impl_dst, opcode, impl_first, impl_second, impl_third);
242}
243
244void KernelWriter::op_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
245{
246 auto impl_lhs = lhs.create_impl_operand(_impl.get());
247 auto impl_rhs = rhs.create_impl_operand(_impl.get());
248
249 _impl->op_if_header(impl_lhs, op, impl_rhs);
250 _impl->compound_statement_begin();
251 body();
252 _impl->compound_statement_end();
253}
254
255void KernelWriter::op_else_if(const TileOperand &lhs, BinaryOp op, const TileOperand &rhs, const std::function<void()> &body)
256{
257 auto impl_lhs = lhs.create_impl_operand(_impl.get());
258 auto impl_rhs = rhs.create_impl_operand(_impl.get());
259
260 _impl->op_else_if_header(impl_lhs, op, impl_rhs);
261 _impl->compound_statement_begin();
262 body();
263 _impl->compound_statement_end();
264}
265
266void KernelWriter::op_else(const std::function<void()> &body)
267{
268 _impl->op_else_header();
269 _impl->compound_statement_begin();
270 body();
271 _impl->compound_statement_end();
272}
273
274void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
275{
276 auto impl_var_name = var_name.create_impl_operand(_impl.get());
277 auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get());
278 auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
279
280 _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name);
281 _impl->compound_statement_begin();
282 body();
283 _impl->compound_statement_end();
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100284}
285
286// =================================================================================================
287// Misc
288// =================================================================================================
289
290void KernelWriter::op_get_global_id(TileOperand &dst, int32_t dim)
291{
292 _impl->op_get_global_id(prototype::Operand(dst.name()), dim);
293}
294
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100295void KernelWriter::op_return()
296{
297 _impl->op_return();
298}
299
// =================================================================================================
// Code generation
// =================================================================================================
303
304std::string KernelWriter::generate_code()
305{
306 return prototype::generate_code(*_kernel->impl(), _kernel->name());
307}
308
309} // namespace ckw