blob: a3a10abd7f6d0c1da81e34ef5d6c4abec4435fd0 [file] [log] [blame]
Georgios Pinitasdef2a852019-02-21 14:47:56 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_TEST_FFT_H__
25#define __ARM_COMPUTE_TEST_FFT_H__
26
27#include "tests/SimpleTensor.h"
28#include "tests/validation/Helpers.h"
29
30namespace arm_compute
31{
32namespace test
33{
34namespace validation
35{
36namespace reference
37{
38enum class FFTDirection
39{
40 Forward,
41 Inverse
42};
43
44/** Performs an one dimensional DFT on a real input.
45 *
46 * @param[in] src Source tensor.
47 *
48 * @return Complex output of length n/2 + 1 due to symmetry.
49 */
50template <typename T>
51SimpleTensor<T> rdft_1d(const SimpleTensor<T> &src);
52
53/** Performs an one dimensional inverse DFT on a real input.
54 *
55 * @param[in] src Source tensor.
56 * @param[in] is_odd (Optional) Specifies if the output has odd dimensions.
57 * Is used by the inverse variant to reconstruct odd sequences.
58 *
59 * @return Complex output of length n/2 + 1 due to symmetry.
60 */
61template <typename T>
62SimpleTensor<T> ridft_1d(const SimpleTensor<T> &src, bool is_odd = false);
63
64/** Performs an one dimensional DFT on a complex input.
65 *
66 * @param[in] src Source tensor.
67 * @param[in] direction Direction of the DFT.
68 *
69 * @return Complex output of same length as input.
70 */
71template <typename T>
72SimpleTensor<T> dft_1d(const SimpleTensor<T> &src, FFTDirection direction);
73
74/** Performs a two dimensional DFT on a real input.
75 *
76 * @param[in] src Source tensor.
77 *
78 * @return Complex output of length n/2 + 1 across width due to symmetry and height of same size as the input.
79 */
80template <typename T>
81SimpleTensor<T> rdft_2d(const SimpleTensor<T> &src);
82
83/** Performs a two dimensional inverse DFT on a real input.
84 *
85 * @param[in] src Source tensor.
86 * @param[in] is_odd (Optional) Specifies if the output has odd dimensions across width.
87 * Is used by the inverse variant to reconstruct odd sequences.
88 *
89 * @return Complex output of length n/2 + 1 across width due to symmetry and height of same size as the input.
90 */
91template <typename T>
92SimpleTensor<T> ridft_2d(const SimpleTensor<T> &src, bool is_odd = false);
93
94/** Performs a two dimensional DFT on a complex input.
95 *
96 * @param[in] src Source tensor.
97 * @param[in] direction Direction of the DFT.
98 *
99 * @return Complex output of same length as input.
100 */
101template <typename T>
102SimpleTensor<T> dft_2d(const SimpleTensor<T> &src, FFTDirection direction);
103
104/** Performs and DFT based convolution on a real input.
105 *
106 * @param[in] src Source tensor.
107 * @param[in] w Weights tensor.
108 * @param[in] conv_info Convolution related metadata.
109 *
110 * @return The output tensor.
111 */
112template <typename T>
113SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w, const PadStrideInfo &conv_info);
114} // namespace reference
115} // namespace validation
116} // namespace test
117} // namespace arm_compute
118#endif /* __ARM_COMPUTE_TEST_FFT_H__ */