blob: c29b307748e3c2554bef71416b4d486b60198687 [file] [log] [blame]
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "src/cl/CLTensorComponent.h"
26#include "ckw/Error.h"
27#include "ckw/types/TensorComponentType.h"
28#include "src/cl/CLTensorArgument.h"
29#include "src/cl/CLTile.h"
30
31namespace ckw
32{
33
34namespace
35{
36
37std::string create_component_name(const std::string &name, TensorComponentType x)
38{
39 std::string var_name(name);
40
41 switch(x)
42 {
43 case TensorComponentType::OffsetFirstElement:
44 var_name += "_offset_first_element";
45 break;
46 case TensorComponentType::Stride0:
47 var_name += "_stride0";
48 break;
49 case TensorComponentType::Stride1:
50 var_name += "_stride1";
51 break;
52 case TensorComponentType::Stride2:
53 var_name += "_stride2";
54 break;
55 case TensorComponentType::Stride3:
56 var_name += "_stride3";
57 break;
58 case TensorComponentType::Stride4:
59 var_name += "_stride4";
60 break;
61 case TensorComponentType::Dim0:
62 var_name += "_dim0";
63 break;
64 case TensorComponentType::Dim1:
65 var_name += "_dim1";
66 break;
67 case TensorComponentType::Dim2:
68 var_name += "_dim2";
69 break;
70 case TensorComponentType::Dim3:
71 var_name += "_dim3";
72 break;
73 case TensorComponentType::Dim4:
74 var_name += "_dim4";
75 break;
76 case TensorComponentType::Dim1xDim2:
77 var_name += "_dim1xdim2";
78 break;
79 case TensorComponentType::Dim2xDim3:
80 var_name += "_dim2xdim3";
81 break;
82 case TensorComponentType::Dim1xDim2xDim3:
83 var_name += "_dim1xdim2xdim3";
84 break;
85 default:
86 CKW_THROW_MSG("Unsupported tensor component");
87 return "";
88 }
89
90 return var_name;
91}
92
93} // namespace
94
95CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type)
96 : CLTile(create_component_name(tensor.name(), component_type), TileInfo(DataType::Int32)), _component_type(component_type)
97{
98}
99
100CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type, int32_t value)
101 : CLTile({ { std::to_string(value) } }, DataType::Int32), _component_type(component_type)
102{
103 CKW_UNUSED(tensor);
104}
105
106CLTensorComponent::~CLTensorComponent() = default;
107
108ITile &CLTensorComponent::tile()
109{
110 return *this;
111}
112
113const ITile &CLTensorComponent::tile() const
114{
115 return *this;
116}
117
118TensorComponentType CLTensorComponent::component_type() const
119{
120 return _component_type;
121}
122
123} // namespace ckw