blob: b19c40d5ea98eb4601f3d4afef8de7a5ac90c562 [file] [log] [blame]
Moritz Pflanzer69d33412017-08-09 11:45:15 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE
25#define ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
29#include "arm_compute/core/Utils.h"
Moritz Pflanzer69d33412017-08-09 11:45:15 +010030#include "tests/AssetsLibrary.h"
31#include "tests/Globals.h"
32#include "tests/IAccessor.h"
33#include "tests/RawTensor.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010034#include "tests/framework/Asserts.h"
35#include "tests/framework/Fixture.h"
36#include "tests/validation/CPP/FullyConnectedLayer.h"
Moritz Pflanzercde1e8a2017-09-08 09:53:14 +010037#include "tests/validation/CPP/Utils.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010038#include "tests/validation/Helpers.h"
Moritz Pflanzer69d33412017-08-09 11:45:15 +010039
40#include <random>
41
42namespace arm_compute
43{
44namespace test
45{
46namespace validation
47{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +010048template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
Moritz Pflanzer69d33412017-08-09 11:45:15 +010049class FullyConnectedLayerValidationFixedPointFixture : public framework::Fixture
50{
51public:
52 template <typename...>
53 void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, int fractional_bits)
54 {
55 ARM_COMPUTE_UNUSED(weights_shape);
56 ARM_COMPUTE_UNUSED(bias_shape);
57
58 _fractional_bits = fractional_bits;
59 _data_type = data_type;
60
61 _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, fractional_bits);
62 _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, fractional_bits);
63 }
64
65protected:
66 template <typename U>
67 void fill(U &&tensor, int i)
68 {
69 if(is_data_type_float(_data_type))
70 {
71 std::uniform_real_distribution<> distribution(0.5f, 1.f);
72 library->fill(tensor, distribution, i);
73 }
74 else
75 {
76 library->fill_tensor_uniform(tensor, i);
77 }
78 }
79
80 TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights,
81 bool reshape_weights, DataType data_type, int fixed_point_position)
82 {
83 TensorShape reshaped_weights_shape(weights_shape);
84
85 // Test actions depending on the target settings
86 //
87 // | reshape | !reshape
88 // -----------+-----------+---------------------------
89 // transpose | | ***
90 // -----------+-----------+---------------------------
91 // !transpose | transpose | transpose &
92 // | | transpose1xW (if required)
93 //
94 // ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same
95 if(!reshape_weights || !transpose_weights)
96 {
97 const size_t shape_x = reshaped_weights_shape.x();
98 reshaped_weights_shape.set(0, reshaped_weights_shape.y());
99 reshaped_weights_shape.set(1, shape_x);
100
101 // Weights have to be passed reshaped
102 // Transpose 1xW for batched version
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100103 if(!reshape_weights && output_shape.y() > 1 && run_interleave)
Moritz Pflanzer69d33412017-08-09 11:45:15 +0100104 {
105 const int transpose_width = 16 / data_size_from_type(data_type);
106 const float shape_x = reshaped_weights_shape.x();
107 reshaped_weights_shape.set(0, reshaped_weights_shape.y() * transpose_width);
108 reshaped_weights_shape.set(1, static_cast<unsigned int>(std::ceil(shape_x / transpose_width)));
109 }
110 }
111
112 // Create tensors
113 TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position);
114 TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, data_type, 1, fixed_point_position);
115 TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position);
116 TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position);
117
118 // Create and configure function.
119 FunctionType fc;
120 fc.configure(&src, &weights, &bias, &dst, transpose_weights, !reshape_weights);
121
122 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
123 ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
124 ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
125 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
126
127 // Allocate tensors
128 src.allocator()->allocate();
129 weights.allocator()->allocate();
130 bias.allocator()->allocate();
131 dst.allocator()->allocate();
132
133 ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
134 ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
135 ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
136 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
137
138 // Fill tensors
139 fill(AccessorType(src), 0);
140 fill(AccessorType(bias), 2);
141
142 if(!reshape_weights || !transpose_weights)
143 {
144 TensorShape tmp_shape(weights_shape);
145 RawTensor tmp(tmp_shape, data_type, 1, fixed_point_position);
146
147 // Fill with original shape
148 fill(tmp, 1);
149
150 // Transpose elementwise
151 tmp = transpose(tmp);
152
153 // Reshape weights for batched runs
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100154 if(!reshape_weights && output_shape.y() > 1 && run_interleave)
Moritz Pflanzer69d33412017-08-09 11:45:15 +0100155 {
156 // Transpose with interleave
157 const int interleave_size = 16 / tmp.element_size();
158 tmp = transpose(tmp, interleave_size);
159 }
160
161 AccessorType weights_accessor(weights);
162
163 for(int i = 0; i < tmp.num_elements(); ++i)
164 {
165 Coordinates coord = index2coord(tmp.shape(), i);
166 std::copy_n(static_cast<const RawTensor::value_type *>(tmp(coord)),
167 tmp.element_size(),
168 static_cast<RawTensor::value_type *>(weights_accessor(coord)));
169 }
170 }
171 else
172 {
173 fill(AccessorType(weights), 1);
174 }
175
176 // Compute NEFullyConnectedLayer function
177 fc.run();
178
179 return dst;
180 }
181
182 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights,
183 bool reshape_weights, DataType data_type, int fixed_point_position = 0)
184 {
185 // Create reference
186 SimpleTensor<T> src{ input_shape, data_type, 1, fixed_point_position };
187 SimpleTensor<T> weights{ weights_shape, data_type, 1, fixed_point_position };
188 SimpleTensor<T> bias{ bias_shape, data_type, 1, fixed_point_position };
189
190 // Fill reference
191 fill(src, 0);
192 fill(weights, 1);
193 fill(bias, 2);
194
195 return reference::fully_connected_layer<T>(src, weights, bias, output_shape);
196 }
197
198 TensorType _target{};
199 SimpleTensor<T> _reference{};
200 int _fractional_bits{};
201 DataType _data_type{};
202};
203
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100204template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
205class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave>
Moritz Pflanzer69d33412017-08-09 11:45:15 +0100206{
207public:
208 template <typename...>
209 void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type)
210 {
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100211 FullyConnectedLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
212 reshape_weights, data_type,
213 0);
Moritz Pflanzer69d33412017-08-09 11:45:15 +0100214 }
215};
216} // namespace validation
217} // namespace test
218} // namespace arm_compute
219#endif /* ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE */