blob: 24836092d4b65ff7f2afd94e5695fe4a6026000b [file] [log] [blame]
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010025#include "src/TensorUtils.h"
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000026#include "ckw/Error.h"
27#include "ckw/TensorInfo.h"
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010028#include "ckw/types/TensorComponentType.h"
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000029
30namespace ckw
31{
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010032TensorComponentType get_tensor_dimension(TensorDataLayout layout, TensorDataLayoutComponent component)
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000033{
34 switch(layout)
35 {
36 case TensorDataLayout::Nhwc:
37 switch(component)
38 {
39 case TensorDataLayoutComponent::C:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010040 return TensorComponentType::Dim0;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000041 case TensorDataLayoutComponent::W:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010042 return TensorComponentType::Dim1;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000043 case TensorDataLayoutComponent::H:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010044 return TensorComponentType::Dim2;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000045 case TensorDataLayoutComponent::N:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010046 return TensorComponentType::Dim3;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000047 default:
48 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component for NHWC");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010049 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000050 }
51 case TensorDataLayout::Ndhwc:
52 switch(component)
53 {
54 case TensorDataLayoutComponent::C:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010055 return TensorComponentType::Dim0;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000056 case TensorDataLayoutComponent::W:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010057 return TensorComponentType::Dim1;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000058 case TensorDataLayoutComponent::H:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010059 return TensorComponentType::Dim2;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000060 case TensorDataLayoutComponent::D:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010061 return TensorComponentType::Dim3;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000062 case TensorDataLayoutComponent::N:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010063 return TensorComponentType::Dim4;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000064 default:
65 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component for NDHWC");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010066 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000067 }
68 default:
69 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor data layout");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010070 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000071 }
72}
73
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010074TensorComponentType get_tensor_stride(TensorDataLayout layout, TensorDataLayoutComponent component)
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000075{
76 switch(layout)
77 {
78 case TensorDataLayout::Nhwc:
79 switch(component)
80 {
81 case TensorDataLayoutComponent::C:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010082 return TensorComponentType::Stride0;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000083 case TensorDataLayoutComponent::W:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010084 return TensorComponentType::Stride1;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000085 case TensorDataLayoutComponent::H:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010086 return TensorComponentType::Stride2;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000087 case TensorDataLayoutComponent::N:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010088 return TensorComponentType::Stride3;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000089 default:
90 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component for NHWC");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010091 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000092 }
93 case TensorDataLayout::Ndhwc:
94 switch(component)
95 {
96 case TensorDataLayoutComponent::C:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010097 return TensorComponentType::Stride0;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +000098 case TensorDataLayoutComponent::W:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +010099 return TensorComponentType::Stride1;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000100 case TensorDataLayoutComponent::H:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100101 return TensorComponentType::Stride2;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000102 case TensorDataLayoutComponent::D:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100103 return TensorComponentType::Stride3;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000104 case TensorDataLayoutComponent::N:
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100105 return TensorComponentType::Stride4;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000106 default:
107 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component for NDHWC");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100108 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000109 }
110 default:
111 COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor data layout");
Gian Marco Iodiceebfdb5a2023-07-07 11:25:57 +0100112 return TensorComponentType::Unknown;
Gian Marco Iodice6c113ed2023-01-19 17:14:26 +0000113 }
114}
115} // namespace ckw