blob: ea360b289e42c5c0e793d76e9837baf9554bb2e8 [file] [log] [blame]
Viet-Hoa Docd1f03e2023-09-19 16:41:34 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifndef CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H
26#define CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H
27
#include "ckw/TileInfo.h"
#include "ckw/types/DataType.h"
#include "ckw/types/Operators.h"
#include "src/cl/CLKernelWriter.h"
#include "validation/tests/common/Common.h"
#include "validation/tests/common/KernelWriterInterceptor.h"

#include <cstdint>
#include <functional>
#include <string>
#include <vector>
37
38namespace ckw
39{
40
41class CLKernelWriterSubTileTest : public ITest
42{
43public:
44 CLKernelWriterSubTileTest()
45 {
46 // These are the definitions of the tiles involving in the writing actions.
47 //
48 // Structure:
49 // * List of tiles:
50 // - Tile full height.
51 // - Tile full width.
52 // - Tile view access type (full tile, vector, scalar).
53 // - Tile view start row.
54 // - Tile view start column.
55 // - The tile name.
56
57 // Vector access.
58 _tests.push_back(
59 { { { 1, 4, AccessType::Vector, 0, 0, "{tile_name}" }, //
60 { 4, 4, AccessType::Vector, 2, 0, "{tile_name}__2" },
61 { 1, 4, AccessType::Full, 0, 0, "{tile_name}" },
62 { 4, 4, AccessType::Vector, 3, 0, "{tile_name}__3" } } });
63
64 // Scalar access.
65 _tests.push_back(
66 { { { 1, 1, AccessType::Full, 0, 0, "{tile_name}" }, //
67 { 4, 8, AccessType::Scalar, 2, 4, "{tile_name}__2.s4" },
68 { 1, 16, AccessType::ScalarOfVector, 0, 10, "{tile_name}.sA" },
69 { 1, 1, AccessType::Scalar, 0, 0, "{tile_name}" } } });
70
71 // These are the definitions of the writing actions.
72 //
73 // Structure:
74 // * Writing function.
75 // * Whether this function only works with scalar value.
76 // * Expected code format.
77
78 _actions.push_back(
79 { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
80 {
81 writer.op_assign(args.at(0), args.at(1));
82 },
83 false,
84 "{op0} = {op1};\n" });
85
86 _actions.push_back(
87 { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
88 {
89 writer.op_unary(args.at(0), UnaryOp::Sqrt, args.at(1));
90 },
91 false,
92 "{op0} = sqrt({op1});\n" });
93
94 _actions.push_back(
95 { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
96 {
97 writer.op_binary(args.at(0), BinaryOp::Add, args.at(1), args.at(2));
98 },
99 false,
100 "{op0} = {op1} + {op2};\n" });
101
102 _actions.push_back(
103 { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
104 {
105 writer.op_ternary(args.at(0), TernaryOp::Clamp, args.at(1), args.at(2), args.at(3));
106 },
107 false,
108 "{op0} = clamp({op1}, {op2}, {op3});\n" });
109
110 _actions.push_back(
111 { [](CLKernelWriter &writer, const std::vector<TileOperand> &args)
112 {
113 writer.op_if(args.at(0), BinaryOp::Greater, args.at(1), [] {});
114 },
115 true,
116 "if ({op0} > {op1})\n{\n}\n" });
117 }
118
119 bool run() override
120 {
121 bool all_tests_passed = true;
122 int32_t test_idx = 0;
123
124 KernelWriterInterceptor<CLKernelWriter> writer;
125
126 for(size_t test_no = 0; test_no < _tests.size(); ++test_no)
127 {
128 const TestInfo &test = _tests.at(test_no);
129
130 // Declare all the tiles and get the full name of those tile operand.
131 std::vector<TileOperand> tiles;
132 std::vector<std::string> expected_tiles_name;
133
134 for(size_t operand_no = 0; operand_no < test.operands.size(); ++operand_no)
135 {
136 const TestOperand &operand = test.operands.at(operand_no);
137 std::string name = "test" + std::to_string(test_no) + "_op" + std::to_string(operand_no);
138
139 const TileOperand full_tile = writer.declare_tile(name, TileInfo(DataType::Fp32, operand.height, operand.width));
140
141 switch(operand.access_type)
142 {
143 case AccessType::Full:
144 tiles.emplace_back(full_tile);
145 break;
146
147 case AccessType::Vector:
148 tiles.emplace_back(full_tile.row(operand.start_row));
149 break;
150
151 case AccessType::Scalar:
152 tiles.emplace_back(full_tile.scalar(operand.start_row, operand.start_col));
153 break;
154
155 case AccessType::ScalarOfVector:
156 tiles.emplace_back(full_tile.row(operand.start_row).scalar(0, operand.start_col));
157 break;
158
159 default:
160 CKW_THROW_MSG("Unsupported access type!");
161 }
162
163 expected_tiles_name.push_back("G0__" + name);
164 }
165
166 // Try each writing action using the newly declared tiles.
167 for(const TestAction &action : _actions)
168 {
169 if(action.scalar_only && //
170 (test.operands.at(0).access_type != AccessType::Scalar && //
171 (test.operands.at(0).height != 1 || test.operands.at(0).width != 1)))
172 {
173 continue;
174 }
175
176 writer.start_capture_code();
177
178 action.write(writer, tiles);
179
180 // The expected code is constructed from the format strings.
181 std::string expected_code = action.expected_code;
182
183 for(size_t operand_no = 0; operand_no < test.operands.size(); ++operand_no)
184 {
185 const TestOperand &operand = test.operands.at(operand_no);
186
187 const std::string op_name = search_and_replace(operand.name, "{tile_name}", expected_tiles_name.at(operand_no));
188 expected_code = search_and_replace(expected_code, "{op" + std::to_string(operand_no) + "}", op_name);
189 }
190
191 VALIDATE_TEST(writer.check_added_code(expected_code), all_tests_passed, test_idx++);
192 }
193 }
194
195 return all_tests_passed;
196 }
197
198 std::string search_and_replace(const std::string &src, const std::string &search, const std::string &replace)
199 {
200 std::string result = src;
201
202 size_t idx = 0;
203
204 while(true)
205 {
206 idx = result.find(search, idx);
207
208 if(idx == std::string::npos)
209 {
210 break;
211 }
212
213 result = result.replace(idx, search.size(), replace);
214 }
215
216 return result;
217 }
218
219 std::string name() override
220 {
221 return "CLKernelWriterSubTileTest";
222 }
223
224private:
225 enum class AccessType
226 {
227 Full,
228 Vector,
229 Scalar,
230 ScalarOfVector,
231 };
232
233 struct TestOperand
234 {
235 int32_t height;
236 int32_t width;
237
238 AccessType access_type;
239 int32_t start_row;
240 int32_t start_col;
241
242 std::string name;
243 };
244
245 struct TestInfo
246 {
247 std::vector<TestOperand> operands;
248 };
249
250 struct TestAction
251 {
252 std::function<void(CLKernelWriter &, const std::vector<TileOperand> &)> write;
253
254 bool scalar_only;
255 std::string expected_code;
256 };
257
258 std::vector<TestInfo> _tests{};
259 std::vector<TestAction> _actions{};
260};
261
262} // namespace ckw
263
264#endif // CKW_VALIDATION_SRC_TESTS_CLKERNELWRITERSUBTILETEST_H