blob: c6725d3b26c2f5c4d398fa2cfc7e16785a3bd107 [file] [log] [blame]
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "ckw/TensorOperand.h"
26#include "ckw/Error.h"
27#include "ckw/Kernel.h"
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010028#include "ckw/TensorInfo.h"
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010029#include "ckw/TileOperand.h"
30#include "src/Prototype.h"
31
32namespace ckw
33{
34
35namespace
36{
37
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010038TensorComponentOperand &get_or_create_component(TensorOperand &tensor, std::unique_ptr<TensorComponentOperand> &ptr, TensorComponentType component)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010039{
40 if(ptr == nullptr)
41 {
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010042 ptr = std::make_unique<TensorComponentOperand>(tensor, component);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010043 }
44
45 return *ptr;
46}
47
48} // namespace
49
50// =================================================================================================
51// TensorOperand
52// =================================================================================================
53
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010054TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info, TensorStorageType storage_type)
55 : OperandBase(name), _info(info), _storage_type(storage_type)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010056{
57}
58
59prototype::Operand TensorOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
60{
61 CKW_UNUSED(writer);
62 return { name() };
63}
64
65const TensorInfo &TensorOperand::info() const
66{
67 return _info;
68}
69
70TensorInfo &TensorOperand::info()
71{
72 return _info;
73}
74
Viet-Hoa Doc8e16172023-06-27 14:09:46 +010075TensorStorageType TensorOperand::storage_type() const
76{
77 return _storage_type;
78}
79
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +010080DataType TensorOperand::data_type() const
81{
82 return _info.data_type();
83}
84
85bool TensorOperand::is_constant() const
86{
87 return false;
88}
89
90const TileOperand &TensorOperand::tile() const
91{
92 return *_tile;
93}
94
95TileOperand &TensorOperand::tile()
96{
97 return *_tile;
98}
99
100TensorOperand &TensorOperand::tile(TileOperand &tile)
101{
102 _tile = &tile;
103 return *this;
104}
105
106const TensorTileSampler &TensorOperand::tile_sampler() const
107{
108 return _tile_sampler;
109}
110
111TensorTileSampler &TensorOperand::tile_sampler()
112{
113 return _tile_sampler;
114}
115
116TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value)
117{
118 _tile_sampler = value;
119 return *this;
120}
121
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100122TensorComponentOperand &TensorOperand::stride1()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100123{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100124 return get_or_create_component(*this, _stride1, TensorComponentType::Stride1);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100125}
126
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100127TensorComponentOperand &TensorOperand::stride2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100128{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100129 return get_or_create_component(*this, _stride2, TensorComponentType::Stride2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100130}
131
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100132TensorComponentOperand &TensorOperand::stride3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100133{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100134 return get_or_create_component(*this, _stride3, TensorComponentType::Stride3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100135}
136
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100137TensorComponentOperand &TensorOperand::stride4()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100138{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100139 return get_or_create_component(*this, _stride4, TensorComponentType::Stride4);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100140}
141
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100142TensorComponentOperand &TensorOperand::dim0()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100143{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100144 return get_or_create_component(*this, _dim0, TensorComponentType::Dim0);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100145}
146
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100147TensorComponentOperand &TensorOperand::dim1()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100148{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100149 return get_or_create_component(*this, _dim1, TensorComponentType::Dim1);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100150}
151
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100152TensorComponentOperand &TensorOperand::dim2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100153{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100154 return get_or_create_component(*this, _dim2, TensorComponentType::Dim2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100155}
156
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100157TensorComponentOperand &TensorOperand::dim3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100158{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100159 return get_or_create_component(*this, _dim3, TensorComponentType::Dim3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100160}
161
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100162TensorComponentOperand &TensorOperand::dim4()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100163{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100164 return get_or_create_component(*this, _dim4, TensorComponentType::Dim4);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100165}
166
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100167TensorComponentOperand &TensorOperand::dim1_dim2()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100168{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100169 return get_or_create_component(*this, _dim1_dim2, TensorComponentType::Dim1xDim2);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100170}
171
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100172TensorComponentOperand &TensorOperand::dim1_dim2_dim3()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100173{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100174 return get_or_create_component(*this, _dim1_dim2_dim3, TensorComponentType::Dim1xDim2xDim3);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100175}
176
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100177TensorComponentOperand &TensorOperand::offset_first_element_in_bytes()
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100178{
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100179 return get_or_create_component(*this, _offset_first_element_in_bytes, TensorComponentType::OffsetFirstElement);
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100180}
181
182// =================================================================================================
183// TensorComponentOperand
184// =================================================================================================
185
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100186TensorComponentOperand::TensorComponentOperand(TensorOperand &tensor, TensorComponentType component)
187 : TileOperand(tensor.name(), DataType::Int32), _tensor(tensor), _component(component)
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100188{
189}
190
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100191TensorOperand &TensorComponentOperand::tensor()
192{
193 return _tensor;
194}
195
196const TensorOperand &TensorComponentOperand::tensor() const
197{
198 return _tensor;
199}
200
201TensorComponentType TensorComponentOperand::component_type() const
202{
203 return _component;
204}
205
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100206prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
207{
208 CKW_UNUSED(writer);
209 prototype::OperandType type{ prototype::OperandType::Unknown };
210
211 switch(_component)
212 {
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100213 case TensorComponentType::OffsetFirstElement:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100214 type = prototype::OperandType::TensorDataOffset;
215 break;
216
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100217 case TensorComponentType::Stride1:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100218 type = prototype::OperandType::TensorStride1;
219 break;
220
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100221 case TensorComponentType::Stride2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100222 type = prototype::OperandType::TensorStride2;
223 break;
224
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100225 case TensorComponentType::Stride3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100226 type = prototype::OperandType::TensorStride3;
227 break;
228
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100229 case TensorComponentType::Stride4:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100230 type = prototype::OperandType::TensorStride4;
231 break;
232
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100233 case TensorComponentType::Dim0:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100234 type = prototype::OperandType::TensorDim0;
235 break;
236
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100237 case TensorComponentType::Dim1:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100238 type = prototype::OperandType::TensorDim1;
239 break;
240
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100241 case TensorComponentType::Dim2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100242 type = prototype::OperandType::TensorDim2;
243 break;
244
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100245 case TensorComponentType::Dim3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100246 type = prototype::OperandType::TensorDim3;
247 break;
248
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100249 case TensorComponentType::Dim4:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100250 type = prototype::OperandType::TensorDim4;
251 break;
252
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100253 case TensorComponentType::Dim1xDim2:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100254 type = prototype::OperandType::TensorDim1xDim2;
255 break;
256
Viet-Hoa Doc8e16172023-06-27 14:09:46 +0100257 case TensorComponentType::Dim1xDim2xDim3:
Viet-Hoa Dobd4f6b92023-05-30 09:34:32 +0100258 type = prototype::OperandType::TensorDim1xDim2xDim3;
259 break;
260
261 default:
262 CKW_ASSERT(false);
263 }
264
265 return prototype::Operand(name(), type);
266}
267
268} // namespace ckw