blob: 5d49901dd0306522f377cece1c88a786727438a8 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
2 * Copyright (c) 2016, 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef __ARM_COMPUTE_NEMAGNITUDEPHASEKERNEL_H__
25#define __ARM_COMPUTE_NEMAGNITUDEPHASEKERNEL_H__
26
27#include "arm_compute/core/NEON/INEKernel.h"
28#include "arm_compute/core/Types.h"
29
30namespace arm_compute
31{
32class ITensor;
33
34/** Template interface for the kernel to compute magnitude and phase */
35template <MagnitudeType mag_type, PhaseType phase_type>
36class NEMagnitudePhaseKernel : public INEKernel
37{
38public:
39 /** Default constructor */
40 NEMagnitudePhaseKernel();
41 /** Destructor */
42 ~NEMagnitudePhaseKernel() = default;
43 /** Prevent instances of this class from being copied (As this class contains pointers) */
44 NEMagnitudePhaseKernel(const NEMagnitudePhaseKernel &) = delete;
45 /** Default move constructor */
46 NEMagnitudePhaseKernel(NEMagnitudePhaseKernel &&) = default;
47 /** Prevent instances of this class from being copied (As this class contains pointers) */
48 NEMagnitudePhaseKernel &operator=(const NEMagnitudePhaseKernel &) = delete;
49 /** Default move assignment operator */
50 NEMagnitudePhaseKernel &operator=(NEMagnitudePhaseKernel &&) = default;
51
52 /** Initialise the kernel's input, output.
53 *
54 * @note At least one of out1 or out2 must be set
55 *
56 * @param[in] gx Gradient X tensor. Data type supported: S16.
57 * @param[in] gy Gradient Y tensor. Data type supported: S16.
58 * @param[out] magnitude (Optional) The output tensor - Magnitude. Data type supported: S16.
59 * @param[out] phase (Optional) The output tensor - Phase. Data type supported: U8.
60 */
61 void configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase);
62
63 // Inherited methods overridden:
64 void run(const Window &window) override;
65
66private:
67 /** Function to perform magnitude on the given window
68 *
69 * @param[in] window Region on which to execute the kernel
70 */
71 void magnitude(const Window &window);
72 /** Function to perform phase on the given window
73 *
74 * @param[in] window Region on which to execute the kernel
75 */
76 void phase(const Window &window);
77 /** Function to perform magnitude and phase on the given window
78 *
79 * @param[in] window Region on which to execute the kernel
80 */
81 void magnitude_phase(const Window &window);
82
83private:
84 /** Common signature for all the specialised MagnitudePhase functions
85 *
86 * @param[in] window Region on which to execute the kernel.
87 */
88 using MagnitudePhaseFunctionPtr = void (NEMagnitudePhaseKernel::*)(const Window &window);
89 /** MagnitudePhase function to use for the particular formats passed to configure() */
90 MagnitudePhaseFunctionPtr _func;
91 const ITensor *_gx; /**< Input gradient X */
92 const ITensor *_gy; /**< Input gradient Y */
93 ITensor *_magnitude; /**< Output - Magnitude */
94 ITensor *_phase; /**< Output - Phase */
95};
96
97#ifdef ARM_COMPUTE_ENABLE_FP16
98/** Template interface for the kernel to compute magnitude and phase */
99template <MagnitudeType mag_type, PhaseType phase_type>
100class NEMagnitudePhaseFP16Kernel : public INEKernel
101{
102public:
103 /** Default constructor */
104 NEMagnitudePhaseFP16Kernel();
105 /** Destructor */
106 ~NEMagnitudePhaseFP16Kernel() = default;
107 /** Prevent instances of this class from being copied (As this class contains pointers) */
108 NEMagnitudePhaseFP16Kernel(const NEMagnitudePhaseFP16Kernel &) = delete;
109 /** Default move constructor */
110 NEMagnitudePhaseFP16Kernel(NEMagnitudePhaseFP16Kernel &&) = default;
111 /** Prevent instances of this class from being copied (As this class contains pointers) */
112 NEMagnitudePhaseFP16Kernel &operator=(const NEMagnitudePhaseFP16Kernel &) = delete;
113 /** Default move assignment operator */
114 NEMagnitudePhaseFP16Kernel &operator=(NEMagnitudePhaseFP16Kernel &&) = default;
115
116 /** Initialise the kernel's input, output.
117 *
118 * @note At least one of out1 or out2 must be set
119 *
120 * @param[in] gx Gradient X tensor. Data type supported: S16.
121 * @param[in] gy Gradient Y tensor. Data type supported: S16.
122 * @param[out] magnitude (Optional) The output tensor - Magnitude. Data type supported: S16.
123 * @param[out] phase (Optional) The output tensor - Phase. Data type supported: U8.
124 */
125 void configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase);
126
127 // Inherited methods overridden:
128 void run(const Window &window) override;
129
130private:
131 /** Function to perform magnitude on the given window
132 *
133 * @param[in] window Region on which to execute the kernel
134 */
135 void magnitude(const Window &window);
136 /** Function to perform phase on the given window
137 *
138 * @param[in] window Region on which to execute the kernel
139 */
140 void phase(const Window &window);
141 /** Function to perform magnitude and phase on the given window
142 *
143 * @param[in] window Region on which to execute the kernel
144 */
145 void magnitude_phase(const Window &window);
146
147 /** Common signature for all the specialised MagnitudePhase functions
148 *
149 * @param[in] window Region on which to execute the kernel.
150 */
151 using MagnitudePhaseFunctionPtr = void (NEMagnitudePhaseFP16Kernel::*)(const Window &window);
152 /** MagnitudePhase function to use for the particular formats passed to configure() */
153 MagnitudePhaseFunctionPtr _func;
154 const ITensor *_gx; /**< Input gradient X */
155 const ITensor *_gy; /**< Input gradient Y */
156 ITensor *_magnitude; /**< Output - Magnitude */
157 ITensor *_phase; /**< Output - Phase */
158};
159#else
160template <MagnitudeType mag_type, PhaseType phase_type>
161using NEMagnitudePhaseFP16Kernel = NEMagnitudePhaseKernel<mag_type, phase_type>;
162#endif
163}
164#endif /* __ARM_COMPUTE_NEMAGNITUDEPHASEKERNEL_H__ */