blob: 00ecc3824e62e62fbcbc4dad78c7865d9d9ba0ae [file] [log] [blame]
/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "ckw/TensorOperand.h"
26#include "ckw/Error.h"
27#include "ckw/Kernel.h"
28#include "ckw/TileOperand.h"
29#include "src/Prototype.h"
30
31namespace ckw
32{
33
34namespace
35{
36
37inline TensorComponentOperand &get_or_create_component(std::unique_ptr<TensorComponentOperand> &ptr, const ::std::string &name, TensorComponent component)
38{
39 if(ptr == nullptr)
40 {
41 ptr = std::make_unique<TensorComponentOperand>(name, component);
42 }
43
44 return *ptr;
45}
46
47} // namespace
48
49// =================================================================================================
50// TensorOperand
51// =================================================================================================
52
53TensorOperand::TensorOperand(const std::string &name, const TensorInfo &info)
54 : OperandBase(name), _info(info)
55{
56}
57
58prototype::Operand TensorOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
59{
60 CKW_UNUSED(writer);
61 return { name() };
62}
63
64const TensorInfo &TensorOperand::info() const
65{
66 return _info;
67}
68
69TensorInfo &TensorOperand::info()
70{
71 return _info;
72}
73
74DataType TensorOperand::data_type() const
75{
76 return _info.data_type();
77}
78
79bool TensorOperand::is_constant() const
80{
81 return false;
82}
83
84const TileOperand &TensorOperand::tile() const
85{
86 return *_tile;
87}
88
89TileOperand &TensorOperand::tile()
90{
91 return *_tile;
92}
93
94TensorOperand &TensorOperand::tile(TileOperand &tile)
95{
96 _tile = &tile;
97 return *this;
98}
99
100const TensorTileSampler &TensorOperand::tile_sampler() const
101{
102 return _tile_sampler;
103}
104
105TensorTileSampler &TensorOperand::tile_sampler()
106{
107 return _tile_sampler;
108}
109
110TensorOperand &TensorOperand::tile_sampler(const TensorTileSampler &value)
111{
112 _tile_sampler = value;
113 return *this;
114}
115
116TileOperand &TensorOperand::stride1()
117{
118 return get_or_create_component(_stride1, name(), TensorComponent::Stride1);
119}
120
121TileOperand &TensorOperand::stride2()
122{
123 return get_or_create_component(_stride2, name(), TensorComponent::Stride2);
124}
125
126TileOperand &TensorOperand::stride3()
127{
128 return get_or_create_component(_stride3, name(), TensorComponent::Stride3);
129}
130
131TileOperand &TensorOperand::stride4()
132{
133 return get_or_create_component(_stride4, name(), TensorComponent::Stride4);
134}
135
136TileOperand &TensorOperand::dim0()
137{
138 return get_or_create_component(_dim0, name(), TensorComponent::Dim0);
139}
140
141TileOperand &TensorOperand::dim1()
142{
143 return get_or_create_component(_dim1, name(), TensorComponent::Dim1);
144}
145
146TileOperand &TensorOperand::dim2()
147{
148 return get_or_create_component(_dim2, name(), TensorComponent::Dim2);
149}
150
151TileOperand &TensorOperand::dim3()
152{
153 return get_or_create_component(_dim3, name(), TensorComponent::Dim3);
154}
155
156TileOperand &TensorOperand::dim4()
157{
158 return get_or_create_component(_dim4, name(), TensorComponent::Dim4);
159}
160
161TileOperand &TensorOperand::dim1_dim2()
162{
163 return get_or_create_component(_dim1_dim2, name(), TensorComponent::Dim1xDim2);
164}
165
166TileOperand &TensorOperand::dim1_dim2_dim3()
167{
168 return get_or_create_component(_dim1_dim2_dim3, name(), TensorComponent::Dim1xDim2xDim3);
169}
170
171TileOperand &TensorOperand::offset_first_element_in_bytes()
172{
173 return get_or_create_component(_offset_first_element_in_bytes, name(), TensorComponent::OffsetFirstElement);
174}
175
176// =================================================================================================
177// TensorComponentOperand
178// =================================================================================================
179
180TensorComponentOperand::TensorComponentOperand(const ::std::string &name, TensorComponent component)
181 : TileOperand(name, DataType::Int32), _component(component)
182{
183}
184
185prototype::Operand TensorComponentOperand::create_impl_operand(prototype::IGpuKernelWriter *writer) const
186{
187 CKW_UNUSED(writer);
188 prototype::OperandType type{ prototype::OperandType::Unknown };
189
190 switch(_component)
191 {
192 case TensorComponent::OffsetFirstElement:
193 type = prototype::OperandType::TensorDataOffset;
194 break;
195
196 case TensorComponent::Stride1:
197 type = prototype::OperandType::TensorStride1;
198 break;
199
200 case TensorComponent::Stride2:
201 type = prototype::OperandType::TensorStride2;
202 break;
203
204 case TensorComponent::Stride3:
205 type = prototype::OperandType::TensorStride3;
206 break;
207
208 case TensorComponent::Stride4:
209 type = prototype::OperandType::TensorStride4;
210 break;
211
212 case TensorComponent::Dim0:
213 type = prototype::OperandType::TensorDim0;
214 break;
215
216 case TensorComponent::Dim1:
217 type = prototype::OperandType::TensorDim1;
218 break;
219
220 case TensorComponent::Dim2:
221 type = prototype::OperandType::TensorDim2;
222 break;
223
224 case TensorComponent::Dim3:
225 type = prototype::OperandType::TensorDim3;
226 break;
227
228 case TensorComponent::Dim4:
229 type = prototype::OperandType::TensorDim4;
230 break;
231
232 case TensorComponent::Dim1xDim2:
233 type = prototype::OperandType::TensorDim1xDim2;
234 break;
235
236 case TensorComponent::Dim1xDim2xDim3:
237 type = prototype::OperandType::TensorDim1xDim2xDim3;
238 break;
239
240 default:
241 CKW_ASSERT(false);
242 }
243
244 return prototype::Operand(name(), type);
245}
246
247} // namespace ckw