blob: 9f04411d929187f9c1f707f20ce5b0efc972592f [file] [log] [blame]
Anthony Barbier7068f992017-10-26 15:23:08 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
Gian Marcoff850932017-12-11 12:37:17 +000025#ifndef __ARM_COMPUTE_GCDROPOUTLAYERKERNEL_H__
26#define __ARM_COMPUTE_GCDROPOUTLAYERKERNEL_H__
Anthony Barbier7068f992017-10-26 15:23:08 +010027
28#include "arm_compute/core/GLES_COMPUTE/IGCKernel.h"
29
30namespace arm_compute
31{
32class IGCTensor;
33
Gian Marcoff850932017-12-11 12:37:17 +000034/** Interface for the dropout layer kernel.
Anthony Barbier7068f992017-10-26 15:23:08 +010035 *
36 * Dropout is used to improve over-fit on neural networks.
37 *
38 */
Gian Marcoff850932017-12-11 12:37:17 +000039class GCDropoutLayerKernel : public IGCKernel
Anthony Barbier7068f992017-10-26 15:23:08 +010040{
41public:
42 /** Default constructor */
Gian Marcoff850932017-12-11 12:37:17 +000043 GCDropoutLayerKernel();
Anthony Barbier7068f992017-10-26 15:23:08 +010044
45 /** Prevent instances of this class from being copied (As this class contains pointers) */
Gian Marcoff850932017-12-11 12:37:17 +000046 GCDropoutLayerKernel(const GCDropoutLayerKernel &) = delete;
Anthony Barbier7068f992017-10-26 15:23:08 +010047
48 /** Prevent instances of this class from being copied (As this class contains pointers) */
Gian Marcoff850932017-12-11 12:37:17 +000049 GCDropoutLayerKernel &operator=(const GCDropoutLayerKernel &) = delete;
Anthony Barbier7068f992017-10-26 15:23:08 +010050
51 /** Allow instances of this class to be moved */
Gian Marcoff850932017-12-11 12:37:17 +000052 GCDropoutLayerKernel(GCDropoutLayerKernel &&) = default;
Anthony Barbier7068f992017-10-26 15:23:08 +010053
54 /** Allow instances of this class to be moved */
Gian Marcoff850932017-12-11 12:37:17 +000055 GCDropoutLayerKernel &operator=(GCDropoutLayerKernel &&) = default;
Anthony Barbier7068f992017-10-26 15:23:08 +010056
57 /** Set the input and output of the kernel.
58 *
59 * @param[in] input The input tensor for this op. Data types supported: F16/F32
60 * @param[out] mask The mask tensor. Data types supported: Same as @p input
61 * @param[out] output The output tensor. Data types supported: Same as @p input
62 * @param[in] ratio Dropout ratio
63 * @param[in] forward Forward or backward propagation
64 *
65 */
66 void configure(const IGCTensor *input, IGCTensor *mask, IGCTensor *output, float ratio, bool forward);
67
68 // Inherited methods overridden:
69 void run(const Window &window) override;
70
71private:
72 const IGCTensor *_input;
73 IGCTensor *_mask;
74 IGCTensor *_output;
75 unsigned int _num_elems_processed_per_iteration;
76};
77}
78
Gian Marcoff850932017-12-11 12:37:17 +000079#endif /*__ARM_COMPUTE_GCDROPOUTLAYERKERNEL_H__ */