blob: 7d4dc958df6d82477e696458addc1d19f1bd13ff [file] [log] [blame]
/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
#include "src/cl/CLTensorArgument.h"
#include "ckw/Error.h"
#include "src/ITensorArgument.h"
#include "src/ITensorComponent.h"
#include "src/cl/CLHelpers.h"
#include "src/cl/CLTensorComponent.h"
#include "src/types/TensorComponentType.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <string>
#include <vector>
35
36namespace ckw
37{
38CLTensorArgument::CLTensorArgument(const std::string &name, const TensorInfo &info, bool return_dims_by_value)
39{
40 _return_dims_by_value = return_dims_by_value;
41 _basename = name;
42 _info = info;
43}
44
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010045CLTensorArgument::~CLTensorArgument() = default;
46
47CLTensorComponent &CLTensorArgument::cl_component(TensorComponentType x)
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010048{
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010049 // Return the component if it has already been created.
50 {
51 const auto it = std::find_if(
52 _components_used.begin(), _components_used.end(),
53 [=](const std::unique_ptr<CLTensorComponent> &item)
54 {
55 return item->component_type() == x;
56 });
57
58 if(it != _components_used.end())
59 {
60 return **it;
61 }
62 }
63
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010064 if(_return_dims_by_value)
65 {
66 uint32_t component_type = static_cast<uint32_t>(x);
67
68 const bool is_dimension = (component_type & static_cast<uint32_t>(TensorComponentBitmask::Dimension)) != 0;
69 const bool is_folded_dimensions = (component_type & static_cast<uint32_t>(TensorComponentBitmask::FoldedDimensions)) != 0;
70
71 constexpr auto bitmask_all = static_cast<uint32_t>(TensorComponentIndexBitmask::All);
72 constexpr auto bitmask_index_0 = static_cast<uint32_t>(TensorComponentIndexBitmask::Index0);
73#ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
74 constexpr auto bitmask_index_1 = static_cast<uint32_t>(TensorComponentIndexBitmask::Index1);
75 constexpr auto bitmask_index_2 = static_cast<uint32_t>(TensorComponentIndexBitmask::Index2);
76 constexpr auto bitmask_index_3 = static_cast<uint32_t>(TensorComponentIndexBitmask::Index3);
77#endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
78
79 // Make sure that the encoding of component type hasn't changed and each nibble is 4 bits apart.
80 CKW_ASSERT(bitmask_all == (bitmask_index_0 | bitmask_index_1 | bitmask_index_2 | bitmask_index_3));
81 CKW_ASSERT(bitmask_index_0 == bitmask_index_1 >> 4);
82 CKW_ASSERT(bitmask_index_1 == bitmask_index_2 >> 4);
83 CKW_ASSERT(bitmask_index_2 == bitmask_index_3 >> 4);
84
85 // If we have a dimension or folded dimensions, we can return the corresponding value if it is not dynamic (not equal to -1)
86 if(is_dimension == true || is_folded_dimensions == true)
87 {
88 component_type = component_type & bitmask_all;
89
90 int32_t idx = 1;
91 for(int32_t i = 0; i < tensor_component_index_max_count; ++i)
92 {
93 uint32_t dim_idx = component_type & bitmask_index_0;
94
95 if(dim_idx == 0)
96 {
97 // Stop at the first nibble containing 0
98 break;
99 }
100
101 // Subtract - 1. Please refer to the TensorComponentIndexBitmask documentation
102 dim_idx -= 1;
103
104 // Get the dimension value
105 const int32_t dim_val = _info.shape()[dim_idx];
106
107 if(dim_val == kDynamicTensorDimensionValue)
108 {
109 // We cannot return the dimension by value if it is dynamic.
110 // Therefore, force the idx variable to kDynamicTensorDimensionValue and break the loop.
111 idx = kDynamicTensorDimensionValue;
112 break;
113 }
114
115 idx *= dim_val;
116
117 // Go to the next nibble
118 component_type >>= 4;
119 }
120
121 if(idx != kDynamicTensorDimensionValue)
122 {
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100123 _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x, idx));
124
125 return *_components_used.back();
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100126 }
127 }
128 }
129
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100130 _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x));
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100131
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100132 return *_components_used.back();
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100133}
134
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100135ITile &CLTensorArgument::component(TensorComponentType x)
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100136{
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100137 return cl_component(x);
138}
139
140TensorStorageVariable &CLTensorArgument::storage(TensorStorageType x)
141{
142 // Return the storage if it has already been created.
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100143 {
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100144 const auto it = std::find_if(
145 _storages_used.begin(), _storages_used.end(),
146 [=](const TensorStorageVariable &item)
147 {
148 return item.type == x;
149 });
150
151 if(it != _storages_used.end())
152 {
153 return *it;
154 }
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100155 }
156
157 TensorStorageVariable t;
158 t.val = create_storage_name(x);
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100159 t.type = x;
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100160
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100161 _storages_used.emplace_back(t);
162
163 return _storages_used.back();
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100164}
165
166std::string CLTensorArgument::create_storage_name(TensorStorageType x) const
167{
168 std::string var_name = _basename;
169
170 switch(x)
171 {
172 case TensorStorageType::BufferUint8Ptr:
173 var_name += "_ptr";
174 break;
175 case TensorStorageType::Texture2dReadOnly:
176 case TensorStorageType::Texture2dWriteOnly:
177 var_name += "_img2d";
178 break;
179 default:
180 CKW_ASSERT_FAILED_MSG("Unsupported tensor storage");
181 return "";
182 }
183
184 return var_name;
185}
186
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100187std::vector<TensorStorageVariable> CLTensorArgument::storages() const
188{
189 std::vector<TensorStorageVariable> storages;
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100190 storages.reserve(_storages_used.size());
191
192 std::copy(_storages_used.begin(), _storages_used.end(), std::back_inserter(storages));
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100193
194 return storages;
195}
196
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100197std::vector<const ITensorComponent *> CLTensorArgument::components() const
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100198{
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100199 std::vector<const ITensorComponent *> components;
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100200
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100201 for(const auto &component : _components_used)
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100202 {
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100203 if(component->is_assignable())
204 {
205 components.push_back(component.get());
206 }
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100207 }
208
209 return components;
210}
211} // namespace ckw