/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef __ARM_COMPUTE_NEFULLYCONNECTEDLAYER_H__
#define __ARM_COMPUTE_NEFULLYCONNECTEDLAYER_H__

#include "arm_compute/runtime/IFunction.h"

#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixAccumulateBiasesKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h"
#include "arm_compute/core/NEON/kernels/NETransposeKernel.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"

namespace arm_compute
{
/** Basic function to reshape the weights of Fully Connected layer with NEON. This function calls the following kernels:
 *
 * -# @ref NETransposeKernel (if @p transpose_weights is set to true)
 * -# @ref NEGEMMTranspose1xWKernel (if @p is_batched_fc_layer is set to true)
 *
 * @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
 */
class NEFullyConnectedLayerReshapeWeights : public IFunction
{
public:
    /** Constructor */
    NEFullyConnectedLayerReshapeWeights(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
    /** Set the input and output tensors.
     *
     * @param[in]  input               Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/QS16/F32.
     * @param[out] output              Destination tensor. Data type supported: Same as @p input.
     * @param[in]  transpose_weights   True if the weights must be transposed.
     * @param[in]  is_batched_fc_layer True if it is a batched fully connected layer.
     */
    void configure(const ITensor *input, ITensor *output, bool transpose_weights, bool is_batched_fc_layer);
    /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayerReshapeWeights
     *
     * @param[in] input               Weights tensor info. The weights must be 2 dimensional. Data types supported: QS8/QS16/F32.
     * @param[in] output              Destination tensor info. Data type supported: Same as @p input.
     * @param[in] transpose_weights   True if the weights must be transposed.
     * @param[in] is_batched_fc_layer True if it is a batched fully connected layer.
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *input, const ITensorInfo *output, bool transpose_weights, bool is_batched_fc_layer);

    // Inherited methods overridden:
    void run() override;

private:
    MemoryGroup              _memory_group;
    NETransposeKernel        _transpose_kernel;
    NEGEMMTranspose1xWKernel _transpose1xW_kernel;
    Tensor                   _transpose_output;
    bool                     _transpose_weights;
    bool                     _is_batched_fc_layer;
};
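
/* A minimal usage sketch for NEFullyConnectedLayerReshapeWeights, kept as a
 * comment so it is illustrative only and not part of the API. It assumes the
 * caller has already created and filled a 2D weights tensor; the shape and the
 * F32 data type below are placeholders, not requirements, and the exact tensor
 * set-up steps may differ in a real application:
 *
 *   Tensor weights;
 *   Tensor reshaped_weights;
 *   weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
 *   weights.allocator()->allocate();
 *   // ... fill the weights buffer ...
 *
 *   NEFullyConnectedLayerReshapeWeights reshape;
 *   reshape.configure(&weights, &reshaped_weights, true, false); // transpose_weights = true, is_batched_fc_layer = false
 *   reshaped_weights.allocator()->allocate();
 *   reshape.run();
 */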

/** Basic function to compute a Fully Connected layer on NEON. This function calls the following NEON kernels:
 * -# @ref NEIm2ColKernel (called when the input comes from a convolutional layer)
 * -# @ref NEFullyConnectedLayerReshapeWeights (if @p are_weights_reshaped flag is set to false) (called once)
 * -# @ref NEGEMMInterleave4x4Kernel (called if we have a multi-batch input)
 * -# @ref NEGEMMMatrixMultiplyKernel
 * -# @ref NEGEMMMatrixAccumulateBiasesKernel (if @p biases is not equal to nullptr)
 *
 * @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
 */
class NEFullyConnectedLayer : public IFunction
{
public:
    /** Constructor */
    NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    NEFullyConnectedLayer(const NEFullyConnectedLayer &) = delete;
    /** Default move constructor */
    NEFullyConnectedLayer(NEFullyConnectedLayer &&) = default;
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    NEFullyConnectedLayer &operator=(const NEFullyConnectedLayer &) = delete;
    /** Default move assignment operator */
    NEFullyConnectedLayer &operator=(NEFullyConnectedLayer &&) = default;
    /** Set the input and output tensors.
     *
     * @param[in]  input                Source tensor. Data type supported: QS8/QS16/F32.
     * @param[in]  weights              Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input.
     * @param[in]  biases               Bias tensor. Can be nullptr. Data type supported: Same as @p input.
     * @param[out] output               Destination tensor. Data type supported: Same as @p input.
     * @param[in]  transpose_weights    (Optional) Transpose the weights tensor if true. Defaults to true.
     * @param[in]  are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
     */
    void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose_weights = true, bool are_weights_reshaped = false);
    /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
     *
     * @param[in] input                Source tensor info. Data type supported: QS8/QS16/F16/F32.
     * @param[in] weights              Weights tensor info. The weights must be 2 dimensional. Data type supported: Same as @p input.
     * @param[in] biases               Bias tensor info. It can be nullptr. Data type supported: Same as @p input.
     * @param[in] output               Destination tensor info. Data type supported: Same as @p input.
     * @param[in] transpose_weights    (Optional) Transpose the weights tensor if true. Defaults to true.
     * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false.
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights = true, bool are_weights_reshaped = false);

    // Inherited methods overridden:
    void run() override;

private:
    MemoryGroup                         _memory_group;
    NEIm2ColKernel                      _im2col_kernel;
    NEFullyConnectedLayerReshapeWeights _reshape_weights_kernel;
    NEGEMMInterleave4x4Kernel           _interleave4x4_kernel;
    NEGEMMMatrixMultiplyKernel          _mm_kernel;
    NEGEMMMatrixAccumulateBiasesKernel  _accumulate_biases_kernel;
    Tensor                              _im2col_output;
    Tensor                              _interleave4x4_output;
    Tensor                              _reshape_weights_output;
    bool                                _are_weights_reshaped;
    bool                                _is_batched_fc_layer;
    bool                                _linearize_input;
    bool                                _accumulate_biases;
    const ITensor                      *_original_weights;
};
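
/* A minimal usage sketch for NEFullyConnectedLayer, kept as a comment and
 * illustrative only. It assumes the source, weights, bias and destination
 * tensors have already been initialised, allocated and filled by the caller;
 * tensor shapes and data types are whatever the application requires:
 *
 *   Tensor src, weights, bias, dst;
 *   // ... init the tensor infos, allocate the buffers, fill src/weights/bias ...
 *
 *   // Optionally check the configuration up front
 *   Status s = NEFullyConnectedLayer::validate(src.info(), weights.info(), bias.info(), dst.info());
 *
 *   NEFullyConnectedLayer fc;
 *   fc.configure(&src, &weights, &bias, &dst); // transpose_weights and are_weights_reshaped keep their defaults
 *   dst.allocator()->allocate();
 *   fc.run();
 */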
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEFULLYCONNECTEDLAYER_H__ */