blob: dbe2036768a8e5b68a79668261efa220bbe55120 [file] [log] [blame]
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#include "src/cl/CLTensorComponent.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010026
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010027#include "ckw/Error.h"
28#include "ckw/types/TensorComponentType.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010029
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010030#include "src/cl/CLTensorArgument.h"
31#include "src/cl/CLTile.h"
32
33namespace ckw
34{
35
36namespace
37{
38
39std::string create_component_name(const std::string &name, TensorComponentType x)
40{
41 std::string var_name(name);
42
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010043 switch (x)
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +010044 {
45 case TensorComponentType::OffsetFirstElement:
46 var_name += "_offset_first_element";
47 break;
48 case TensorComponentType::Stride0:
49 var_name += "_stride0";
50 break;
51 case TensorComponentType::Stride1:
52 var_name += "_stride1";
53 break;
54 case TensorComponentType::Stride2:
55 var_name += "_stride2";
56 break;
57 case TensorComponentType::Stride3:
58 var_name += "_stride3";
59 break;
60 case TensorComponentType::Stride4:
61 var_name += "_stride4";
62 break;
63 case TensorComponentType::Dim0:
64 var_name += "_dim0";
65 break;
66 case TensorComponentType::Dim1:
67 var_name += "_dim1";
68 break;
69 case TensorComponentType::Dim2:
70 var_name += "_dim2";
71 break;
72 case TensorComponentType::Dim3:
73 var_name += "_dim3";
74 break;
75 case TensorComponentType::Dim4:
76 var_name += "_dim4";
77 break;
78 case TensorComponentType::Dim1xDim2:
79 var_name += "_dim1xdim2";
80 break;
81 case TensorComponentType::Dim2xDim3:
82 var_name += "_dim2xdim3";
83 break;
84 case TensorComponentType::Dim1xDim2xDim3:
85 var_name += "_dim1xdim2xdim3";
86 break;
87 default:
88 CKW_THROW_MSG("Unsupported tensor component");
89 return "";
90 }
91
92 return var_name;
93}
94
95} // namespace
96
97CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010098 : CLTile(create_component_name(tensor.name(), component_type), TileInfo(DataType::Int32)),
99 _component_type(component_type)
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100100{
101}
102
103CLTensorComponent::CLTensorComponent(const CLTensorArgument &tensor, TensorComponentType component_type, int32_t value)
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100104 : CLTile({{std::to_string(value)}}, DataType::Int32), _component_type(component_type)
Viet-Hoa Do0b23e0e2023-07-25 14:00:46 +0100105{
106 CKW_UNUSED(tensor);
107}
108
109CLTensorComponent::~CLTensorComponent() = default;
110
111ITile &CLTensorComponent::tile()
112{
113 return *this;
114}
115
116const ITile &CLTensorComponent::tile() const
117{
118 return *this;
119}
120
121TensorComponentType CLTensorComponent::component_type() const
122{
123 return _component_type;
124}
125
126} // namespace ckw