blob: 52773faa3a66eb00bcff67ce3aabb08b067c238f [file] [log] [blame]
Georgios Pinitas358ca202017-12-07 16:47:52 +00001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H__
25#define __ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H__
26
27#include "arm_compute/core/ITensorInfo.h"
28
29namespace arm_compute
30{
31namespace misc
32{
33namespace shape_calculator
34{
35inline TensorShape compute_interleaved_shape(const ITensorInfo &a)
36{
37 // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
38 TensorShape shape_interleaved_a{ a.tensor_shape() };
39 shape_interleaved_a.set(0, a.dimension(0) * 4);
40 shape_interleaved_a.set(1, std::ceil(a.dimension(1) / 4.f));
41
42 return shape_interleaved_a;
43}
44inline TensorShape compute_transpose1xW_shape(const ITensorInfo &b)
45{
46 // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
47 TensorShape shape_transposed1xW_b{ b.tensor_shape() };
48 shape_transposed1xW_b.set(0, b.dimension(1) * 16);
49 shape_transposed1xW_b.set(1, std::ceil(b.dimension(0) / 16.f));
50
51 return shape_transposed1xW_b;
52}
53inline TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b)
54{
55 // The transpose1xW output matrix will have the following shape:
56 // [ b_height * (16 / element_size), ceil(b_width / (16.0f / element_size) ]
57 TensorShape shape_transposed1xW_b{ b.tensor_shape() };
58 const size_t transpose_width = 16 / b.element_size();
59 shape_transposed1xW_b.set(0, b.dimension(1) * transpose_width);
60 shape_transposed1xW_b.set(1, static_cast<size_t>(std::ceil(b.dimension(0) / static_cast<float>(transpose_width))));
61
62 return shape_transposed1xW_b;
63}
64inline TensorShape compute_reductionA_shape(const ITensorInfo &b)
65{
66 TensorShape shape_vector_sum_col{ b.tensor_shape() };
67 if(shape_vector_sum_col.num_dimensions() > 1)
68 {
69 shape_vector_sum_col.remove_dimension(1);
70 }
71
72 return shape_vector_sum_col;
73}
74inline TensorShape compute_reductionB_shape(const ITensorInfo &a)
75{
76 TensorShape shape_vector_sum_row{ a.tensor_shape() };
77 shape_vector_sum_row.set(Window::DimX, a.dimension(1));
78 if(a.num_dimensions() > 1)
79 {
80 shape_vector_sum_row.remove_dimension(1);
81 }
82
83 return shape_vector_sum_row;
84}
85inline TensorShape compute_im2col_shape(const ITensorInfo &input)
86{
87 TensorShape shape_im2col{ input.tensor_shape() };
88 shape_im2col.collapse(3);
89
90 return shape_im2col;
91}
92inline TensorShape compute_transposed_shape(const ITensorInfo &input)
93{
94 TensorShape shape_transposed{ input.tensor_shape() };
95
96 shape_transposed.set(0, input.dimension(1));
97 shape_transposed.set(1, input.dimension(0));
98
99 return shape_transposed;
100}
101} // namespace shape_calculator
102} // namespace misc
103} // namespace arm_compute
104#endif /* __ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H__ */