blob: 6a9884543c5f3ed499aab1365c50468183a88ac9 [file] [log] [blame]
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "ckw/Error.h"
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010026#include "ckw/KernelArgument.h"
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010027#include "ckw/KernelWriter.h"
28#include "ckw/TensorOperand.h"
29#include "ckw/TensorTileSampler.h"
30#include "ckw/TileOperand.h"
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010031
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010032#include "common/ExampleComponentArgument.h"
33#include "common/ExampleKernelWriter.h"
34#include "common/ExampleScopedKernelWriter.h"
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010035
36#include <iostream>
37#include <vector>
38
39using namespace ckw;
40
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010041TensorTileSampler create_simple_sampler(ExampleScopedKernelWriter writer)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010042{
43 TensorTileSampler sampler;
44
45 constexpr int32_t m0 = 4;
46 constexpr int32_t n0 = 4;
47
48 auto &gid_0 = writer->declare_tile("gid_0", DataType::Int32);
49 auto &gid_1 = writer->declare_tile("gid_1", DataType::Int32);
50 auto &gid_2 = writer->declare_tile("gid_2", DataType::Int32);
51
52 auto &const_0 = writer->declare_tile("0", 0);
53
54 writer->op_get_global_id(gid_0, 0);
55 writer->op_get_global_id(gid_1, 1);
56 writer->op_get_global_id(gid_2, 2);
57
58 sampler.x(gid_0);
59 sampler.y(gid_1);
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010060 sampler.z(const_0);
61 sampler.b(gid_2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010062
63 sampler.width(n0);
64 sampler.height(m0);
65
66 sampler.format(TensorSamplerFormat::C_WH_1);
67 sampler.address_mode_x(TensorSamplerAddressModeX::None);
68 sampler.address_mode_y(TensorSamplerAddressModeY::ClampToBorder);
69 sampler.address_mode_z(TensorSamplerAddressModeZ::Skip);
70
71 return sampler;
72}
73
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +010074void op_binary_elementwise(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010075{
76 auto lhs = operands.at(0);
77 auto rhs = operands.at(1);
78 auto dst = operands.at(2);
79
80 // Load the LHS and RHS tile and prepare the tensor sampler.
81 if(!lhs->has_tile() && !rhs->has_tile())
82 {
83 const auto sampler = create_simple_sampler(writer);
84
85 writer->op_load_once(lhs, sampler);
86 writer->op_load_once(rhs, sampler);
87 }
88 else if(lhs->has_tile())
89 {
90 const auto &sampler = lhs->tile_sampler();
91 writer->op_load_once(rhs, sampler);
92 }
93 else
94 {
95 const auto &sampler = rhs->tile_sampler();
96 writer->op_load_once(lhs, sampler);
97 }
98
99 auto &lhs_tile = lhs->tile();
100 auto &rhs_tile = rhs->tile();
101 const auto &sampler = lhs->tile_sampler();
102
103 // Prepare the output tile.
104 if(!dst->has_tile())
105 {
106 auto &tile = writer->declare_tile("dst_tile", lhs_tile.tile_info());
107 dst->init_virtual_tensor(tile, sampler);
108 }
109
110 auto &dst_tile = dst->tile();
111
112 // Perform the operation.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100113 writer->op_binary_expression(dst_tile, lhs_tile, BinaryOp::Add, rhs_tile);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100114}
115
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100116void op_exp(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100117{
118 auto src = operands.at(0);
119 auto dst = operands.at(1);
120
121 // Load the source tile and prepare the sampler.
122 if(!src->has_tile())
123 {
124 const auto sampler = create_simple_sampler(writer);
125 writer->op_load_once(src, sampler);
126 }
127
128 auto &src_tile = src->tile();
129 const auto &sampler = src->tile_sampler();
130
131 // Prepare the output tile.
132 if(!dst->has_tile())
133 {
134 auto &tile = writer->declare_tile("dst_tile", src_tile.tile_info());
135 dst->init_virtual_tensor(tile, sampler);
136 }
137
138 auto &dst_tile = dst->tile();
139
140 // Perform the operation.
Nikolaj Jensen5ff48022023-06-27 14:13:24 +0100141 writer->op_unary_elementwise_function(dst_tile, UnaryFunction::Exp, src_tile);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100142}
143
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100144void op_store(ExampleScopedKernelWriter writer, std::vector<ExampleComponentArgument *> operands)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100145{
146 auto src = operands.at(0);
147 auto dst = operands.at(1);
148
149 auto &src_tile = src->tile();
150 const auto &sampler = src->tile_sampler();
151 auto &dst_tensor = dst->tensor();
152
153 writer->op_store(dst_tensor, src_tile, sampler);
154}
155
156int main()
157{
Nikolaj Jensenacea4072023-07-03 09:44:42 +0100158 Kernel kernel("example", GpuTargetLanguage::OpenCL);
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100159 ExampleKernelWriter root_writer(kernel);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100160
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100161 ExampleScopedKernelWriter writer(&root_writer);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100162
163 const TensorInfo src0_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 0);
164 const TensorInfo src1_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 1);
165 const TensorInfo dst_info(DataType::Fp32, TensorShape({ 3, 10, 20, 1, 1 }), TensorDataLayout::Nhwc, 2);
166
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100167 ExampleComponentArgument src0(writer->declare_tensor_argument("src0", src0_info, TensorStorageType::BufferUint8Ptr));
168 ExampleComponentArgument src1(writer->declare_tensor_argument("src1", src1_info, TensorStorageType::BufferUint8Ptr));
169 ExampleComponentArgument dst(writer->declare_tensor_argument("dst", dst_info, TensorStorageType::BufferUint8Ptr));
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100170
Viet-Hoa Doce3c48c2023-07-03 13:44:43 +0100171 ExampleComponentArgument ans;
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100172
173 op_binary_elementwise(writer, { &src0, &src1, &ans });
174 op_exp(writer, { &ans, &ans });
175 op_store(writer, { &ans, &dst });
176
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100177 const auto arguments = kernel.arguments();
178
179 std::cout << "\n====================\nArguments:\n====================\n";
180
181 for(auto &arg : arguments)
182 {
183 switch(arg.type())
184 {
185 case ckw::KernelArgument::Type::TensorStorage:
186 std::cout << "* Tensor storage: ID = " << arg.id() << ", type = " << std::hex << "0x" << static_cast<uint32_t>(arg.tensor_storage_type()) << std::dec << "\n";
187 break;
188
189 case ckw::KernelArgument::Type::TensorComponent:
190 std::cout << "* Tensor component: ID = " << arg.id() << ", type = " << std::hex << "0x" << static_cast<uint32_t>(arg.tensor_component_type()) << std::dec << "\n";
191 break;
192
193 default:
194 CKW_ASSERT(false);
195 }
196 }
197
198 std::cout << "\n====================\nCode:\n====================\n";
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100199 const auto code = root_writer.generate_code();
200 std::cout << code;
201
202 return 0;
203}