blob: d1aefbbb718ca7e94865e7cca884a1fa8de4730c [file] [log] [blame]
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "ckw/TensorOperand.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010026
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010027#include "ckw/Error.h"
28#include "ckw/Kernel.h"
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010029#include "ckw/TensorInfo.h"
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010030#include "ckw/TileOperand.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010031
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010032#include "src/Prototype.h"
33
34namespace ckw
35{
36
37namespace
38{
39
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010040TensorComponentOperand &get_or_create_component(TensorOperand &tensor,
41 std::unique_ptr<TensorComponentOperand> &ptr,
42 TensorComponentType component)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010043{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010044 if (ptr == nullptr)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010045 {
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010046 ptr = std::make_unique<TensorComponentOperand>(tensor, component);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010047 }
48
49 return *ptr;
50}
51
52} // namespace
53
54// =================================================================================================
55// TensorOperand
56// =================================================================================================
57
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010058TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
59 : OperandBase(name), _info(info), _storage_type(storage_type)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010060{
61}
62
63prototype::Operand TensorOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
64{
65 CKW_UNUSED(writer);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010066 return {name()};
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010067}
68
69const TensorInfo &TensorOperand::info() const
70{
71 return _info;
72}
73
74TensorInfo &TensorOperand::info()
75{
76 return _info;
77}
78
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010079TensorStorageType TensorOperand::storage_type() const
80{
81 return _storage_type;
82}
83
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010084DataType TensorOperand::data_type() const
85{
86 return _info.data_type();
87}
88
89bool TensorOperand::is_constant() const
90{
91 return false;
92}
93
94const TileOperand &TensorOperand::tile() const
95{
96 return *_tile;
97}
98
99TileOperand &TensorOperand::tile()
100{
101 return *_tile;
102}
103
104TensorOperand &TensorOperand::tile(TileOperand &tile)
105{
106 _tile = &tile;
107 return *this;
108}
109
110const TensorTileSampler &TensorOperand::tile_sampler() const
111{
112 return _tile_sampler;
113}
114
115TensorTileSampler &TensorOperand::tile_sampler()
116{
117 return _tile_sampler;
118}
119
120TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value)
121{
122 _tile_sampler = value;
123 return *this;
124}
125
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100126TensorComponentOperand &TensorOperand::stride1()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100127{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100128 return get_or_create_component(*this, _stride1, TensorComponentType::Stride1);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100129}
130
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100131TensorComponentOperand &TensorOperand::stride2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100132{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100133 return get_or_create_component(*this, _stride2, TensorComponentType::Stride2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100134}
135
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100136TensorComponentOperand &TensorOperand::stride3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100137{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100138 return get_or_create_component(*this, _stride3, TensorComponentType::Stride3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100139}
140
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100141TensorComponentOperand &TensorOperand::stride4()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100142{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100143 return get_or_create_component(*this, _stride4, TensorComponentType::Stride4);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100144}
145
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100146TensorComponentOperand &TensorOperand::dim0()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100147{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100148 return get_or_create_component(*this, _dim0, TensorComponentType::Dim0);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100149}
150
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100151TensorComponentOperand &TensorOperand::dim1()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100152{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100153 return get_or_create_component(*this, _dim1, TensorComponentType::Dim1);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100154}
155
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100156TensorComponentOperand &TensorOperand::dim2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100157{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100158 return get_or_create_component(*this, _dim2, TensorComponentType::Dim2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100159}
160
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100161TensorComponentOperand &TensorOperand::dim3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100162{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100163 return get_or_create_component(*this, _dim3, TensorComponentType::Dim3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100164}
165
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100166TensorComponentOperand &TensorOperand::dim4()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100167{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100168 return get_or_create_component(*this, _dim4, TensorComponentType::Dim4);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100169}
170
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100171TensorComponentOperand &TensorOperand::dim1_dim2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100172{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100173 return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100174}
175
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100176TensorComponentOperand &TensorOperand::dim1_dim2_dim3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100177{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100178 return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100179}
180
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100181TensorComponentOperand &TensorOperand::offset_first_element_in_bytes()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100182{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100183 return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100184}
185
186// =================================================================================================
187// TensorComponentOperand
188// =================================================================================================
189
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100190TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component)
191 : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100192{
193}
194
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100195TensorOperand &TensorComponentOperand::tensor()
196{
197 return _tensor;
198}
199
200const TensorOperand &TensorComponentOperand::tensor() const
201{
202 return _tensor;
203}
204
205TensorComponentType TensorComponentOperand::component_type() const
206{
207 return _component;
208}
209
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100210prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
211{
212 CKW_UNUSED(writer);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100213 prototype::OperandType type{prototype::OperandType::Unknown};
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100214
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100215 switch (_component)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100216 {
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100217 case TensorComponentType::OffsetFirstElement:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100218 type = prototype::OperandType::TensorDataOffset;
219 break;
220
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100221 case TensorComponentType::Stride1:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100222 type = prototype::OperandType::TensorStride1;
223 break;
224
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100225 case TensorComponentType::Stride2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100226 type = prototype::OperandType::TensorStride2;
227 break;
228
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100229 case TensorComponentType::Stride3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100230 type = prototype::OperandType::TensorStride3;
231 break;
232
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100233 case TensorComponentType::Stride4:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100234 type = prototype::OperandType::TensorStride4;
235 break;
236
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100237 case TensorComponentType::Dim0:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100238 type = prototype::OperandType::TensorDim0;
239 break;
240
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100241 case TensorComponentType::Dim1:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100242 type = prototype::OperandType::TensorDim1;
243 break;
244
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100245 case TensorComponentType::Dim2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100246 type = prototype::OperandType::TensorDim2;
247 break;
248
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100249 case TensorComponentType::Dim3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100250 type = prototype::OperandType::TensorDim3;
251 break;
252
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100253 case TensorComponentType::Dim4:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100254 type = prototype::OperandType::TensorDim4;
255 break;
256
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100257 case TensorComponentType::Dim1xDim2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100258 type = prototype::OperandType::TensorDim1xDim2;
259 break;
260
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100261 case TensorComponentType::Dim1xDim2xDim3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100262 type = prototype::OperandType::TensorDim1xDim2xDim3;
263 break;
264
265 default:
266 CKW_ASSERT(false);
267 }
268
269 return prototype::Operand(name(), type);
270}
271
272} // namespace ckw