blob: 4c1b593793ea9a9aa3ccd982229ca8946b381fb5 [file] [log] [blame]
John Kesapides341b2182019-02-22 10:05:29 +00001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2019-2020 Arm Limited.
John Kesapides341b2182019-02-22 10:05:29 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph.h"
25
John Kesapides341b2182019-02-22 10:05:29 +000026#include "tests/NEON/Accessor.h"
27#include "tests/validation/Validation.h"
28#include "tests/validation/reference/FullyConnectedLayer.h"
29#include "tests/validation/reference/Permute.h"
30
31#include "utils/CommonGraphOptions.h"
32#include "utils/GraphUtils.h"
33#include "utils/Utils.h"
34
35#include "ValidateExample.h"
John Kesapides8d942692019-02-26 14:52:12 +000036#include "graph_validate_utils.h"
John Kesapides341b2182019-02-22 10:05:29 +000037
38#include <utility>
39
40using namespace arm_compute::utils;
41using namespace arm_compute::graph::frontend;
42using namespace arm_compute::graph_utils;
43using namespace arm_compute::graph;
44using namespace arm_compute;
45using namespace arm_compute::test;
46using namespace arm_compute::test::validation;
John Kesapides8d942692019-02-26 14:52:12 +000047
John Kesapides341b2182019-02-22 10:05:29 +000048namespace
49{
John Kesapides8d942692019-02-26 14:52:12 +000050/** Fully connected command line options used to configure the graph examples
John Kesapides341b2182019-02-22 10:05:29 +000051 *
52 * (Similar to common options)
53 * The options in this object get populated when "parse()" is called on the parser used to construct it.
54 * The expected workflow is:
55 *
56 * CommandLineParser parser;
57 * CommonOptions options( parser );
58 * parser.parse(argc, argv);
59 */
John Kesapides8d942692019-02-26 14:52:12 +000060class FullyConnectedOptions final : public CommonGraphValidateOptions
John Kesapides341b2182019-02-22 10:05:29 +000061{
62public:
63 explicit FullyConnectedOptions(CommandLineParser &parser) noexcept
John Kesapides8d942692019-02-26 14:52:12 +000064 : CommonGraphValidateOptions(parser),
65 width(parser.add_option<SimpleOption<int>>("width", 3)),
John Kesapides341b2182019-02-22 10:05:29 +000066 batch(parser.add_option<SimpleOption<int>>("batch", 1)),
John Kesapides341b2182019-02-22 10:05:29 +000067 input_scale(parser.add_option<SimpleOption<float>>("input_scale", 1.0f)),
68 input_offset(parser.add_option<SimpleOption<int>>("input_offset", 0)),
69 weights_scale(parser.add_option<SimpleOption<float>>("weights_scale", 1.0f)),
70 weights_offset(parser.add_option<SimpleOption<int>>("weights_offset", 0)),
71 output_scale(parser.add_option<SimpleOption<float>>("output_scale", 1.0f)),
72 output_offset(parser.add_option<SimpleOption<int>>("output_offset", 0)),
73 num_outputs(parser.add_option<SimpleOption<int>>("num_outputs", 1)),
74 input_range_low(parser.add_option<SimpleOption<uint64_t>>("input_range_low")),
75 input_range_high(parser.add_option<SimpleOption<uint64_t>>("input_range_high")),
76 weights_range_low(parser.add_option<SimpleOption<uint64_t>>("weights_range_low")),
77 weights_range_high(parser.add_option<SimpleOption<uint64_t>>("weights_range_high"))
78 {
John Kesapides341b2182019-02-22 10:05:29 +000079 width->set_help("Set Input dimension width");
80 batch->set_help("Set Input dimension batch");
John Kesapides341b2182019-02-22 10:05:29 +000081 input_scale->set_help("Quantization scale from QASYMM8");
82 input_offset->set_help("Quantization offset from QASYMM8");
83 weights_scale->set_help("Quantization scale from QASYMM8");
84 weights_offset->set_help("Quantization offset from QASYMM8");
85 output_scale->set_help("Quantization scale from QASYMM8");
86 output_offset->set_help("Quantization offset from QASYMM8");
87 num_outputs->set_help("Number of outputs.");
88 input_range_low->set_help("Lower bound for input randomization range");
89 input_range_high->set_help("Lower bound for input randomization range");
90 weights_range_low->set_help("Lower bound for input randomization range");
91 weights_range_high->set_help("Lower bound for input randomization range");
92 }
93
John Kesapides8d942692019-02-26 14:52:12 +000094 /** Fill out the supplied parameters with user supplied parameters
95 *
96 * @param[out] os Output stream.
97 * @param[in] common_params Example parameters to output
98 *
99 * @return None.
100 */
101 void consume_parameters(ExampleParams &common_params)
102 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100103 common_params.input.width = width->value();
104 common_params.input.batch = batch->value();
105 common_params.input.quant_info = QuantizationInfo(input_scale->value(), input_offset->value());
106 common_params.input.range_low = input_range_low->value();
107 common_params.input.range_high = input_range_high->value();
John Kesapides8d942692019-02-26 14:52:12 +0000108
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100109 common_params.weights.quant_info = QuantizationInfo(weights_scale->value(), weights_offset->value());
110 common_params.weights.range_low = weights_range_low->value();
111 common_params.weights.range_high = weights_range_high->value();
John Kesapides8d942692019-02-26 14:52:12 +0000112
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100113 common_params.output.quant_info = QuantizationInfo(output_scale->value(), output_offset->value());
John Kesapides8d942692019-02-26 14:52:12 +0000114
115 common_params.data_type = data_type->value();
116 common_params.fully_connected.num_outputs = num_outputs->value();
117 }
118
119 void print_parameters(::std::ostream &os, const ExampleParams &common_params) override
120 {
121 os << "Threads : " << common_params.common_params.threads << std::endl;
122 os << "Target : " << common_params.common_params.target << std::endl;
123 os << "Data type : " << common_params.data_type << std::endl;
124 os << "Input dimensions(X,Y, Channels, Batch) : (" << common_params.input.width << "," << common_params.input.height << "," << common_params.input.fm << "," << common_params.input.batch << ")"
125 << std::endl;
126 os << "Number of outputs : " << common_params.fully_connected.num_outputs << std::endl;
127 }
128
John Kesapides341b2182019-02-22 10:05:29 +0000129 /** Prevent instances of this class from being copied (As this class contains pointers) */
130 FullyConnectedOptions(const FullyConnectedOptions &) = delete;
131 /** Prevent instances of this class from being copied (As this class contains pointers) */
132 FullyConnectedOptions &operator=(const FullyConnectedOptions &) = delete;
133 /** Allow instances of this class to be moved */
134 FullyConnectedOptions(FullyConnectedOptions &&) noexcept(true) = default;
135 /** Allow instances of this class to be moved */
136 FullyConnectedOptions &operator=(FullyConnectedOptions &&) noexcept(true) = default;
137 /** Default destructor */
John Kesapides8d942692019-02-26 14:52:12 +0000138 ~FullyConnectedOptions() override = default;
John Kesapides341b2182019-02-22 10:05:29 +0000139
Michalis Spyrou299fdd32019-05-01 13:03:59 +0100140private:
John Kesapides8d942692019-02-26 14:52:12 +0000141 SimpleOption<int> *width; /**< Input width */
142 SimpleOption<int> *batch; /**< Input batch */
143 SimpleOption<float> *input_scale; /**< Input Quantization scale from QASSYMM8 */
144 SimpleOption<int> *input_offset; /**< Input Quantization offset from QASSYMM8 */
145 SimpleOption<float> *weights_scale; /**< Weights Quantization scale from QASSYMM8 */
146 SimpleOption<int> *weights_offset; /**< Weights Quantization offset from QASSYMM8 */
147 SimpleOption<float> *output_scale; /**< Output Quantization scale from QASSYMM8 */
148 SimpleOption<int> *output_offset; /**< Output Quantization offset from QASSYMM8 */
149 SimpleOption<int> *num_outputs; /**< Number of outputs. */
150 SimpleOption<uint64_t> *input_range_low; /**< Lower bound for input randomization range */
151 SimpleOption<uint64_t> *input_range_high; /**< Upper bound for input randomization range */
152 SimpleOption<uint64_t> *weights_range_low; /**< Lower bound for weights randomization range */
153 SimpleOption<uint64_t> *weights_range_high; /**< Upper bound for weights randomization range */
John Kesapides341b2182019-02-22 10:05:29 +0000154};
155
John Kesapides8d942692019-02-26 14:52:12 +0000156/** Fully Connected Layer Graph example validation accessor class */
John Kesapides341b2182019-02-22 10:05:29 +0000157template <typename D>
John Kesapides8d942692019-02-26 14:52:12 +0000158class FullyConnectedVerifyAccessor final : public VerifyAccessor<D>
John Kesapides341b2182019-02-22 10:05:29 +0000159{
John Kesapides8d942692019-02-26 14:52:12 +0000160 using BaseClassType = VerifyAccessor<D>;
161 using BaseClassType::BaseClassType;
162 using BaseClassType::_params;
John Kesapides341b2182019-02-22 10:05:29 +0000163 using TBias = typename std::conditional<std::is_same<typename std::decay<D>::type, uint8_t>::value, int32_t, D>::type;
164
John Kesapides8d942692019-02-26 14:52:12 +0000165 // Inherited methods overriden:
166 void create_tensors(arm_compute::test::SimpleTensor<D> &src,
167 arm_compute::test::SimpleTensor<D> &weights,
168 arm_compute::test::SimpleTensor<TBias> &bias,
169 ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000170 {
John Kesapides341b2182019-02-22 10:05:29 +0000171 // Calculate Tensor shapes for verification
172 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
John Kesapides8d942692019-02-26 14:52:12 +0000173 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000174 const TensorDescriptor weights_descriptor = FullyConnectedLayerNode::compute_weights_descriptor(input_descriptor,
175 _params.fully_connected.num_outputs,
176 _params.fully_connected.info,
177 _params.weights.quant_info);
178 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
179
180 //Create Input tensors
John Kesapides8d942692019-02-26 14:52:12 +0000181 src = SimpleTensor<D> { input_descriptor.shape, _params.data_type, 1, input_descriptor.quant_info };
182 weights = SimpleTensor<D> { weights_descriptor.shape, _params.data_type, 1, weights_descriptor.quant_info };
183 bias = SimpleTensor<TBias> { TensorShape(tensor.info()->tensor_shape().x()), _params.data_type, 1, _params.input.quant_info };
John Kesapides341b2182019-02-22 10:05:29 +0000184 }
185
John Kesapides8d942692019-02-26 14:52:12 +0000186 TensorShape output_shape(ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000187 {
John Kesapides8d942692019-02-26 14:52:12 +0000188 ARM_COMPUTE_UNUSED(tensor);
John Kesapides341b2182019-02-22 10:05:29 +0000189
John Kesapides8d942692019-02-26 14:52:12 +0000190 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
191 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
192 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000193
John Kesapides8d942692019-02-26 14:52:12 +0000194 return output_desciptor.shape;
John Kesapides341b2182019-02-22 10:05:29 +0000195 }
John Kesapides8d942692019-02-26 14:52:12 +0000196
197 arm_compute::test::SimpleTensor<D> reference(arm_compute::test::SimpleTensor<D> &src,
198 arm_compute::test::SimpleTensor<D> &weights,
199 arm_compute::test::SimpleTensor<TBias> &bias,
200 const arm_compute::TensorShape &output_shape) override
201 {
202 return reference::fully_connected_layer<D>(src, weights, bias, output_shape, _params.output.quant_info);
203 }
204
205 float relative_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000206 {
207 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
208 {
209 {
210 arm_compute::graph::Target::CL,
211 { { DataType::F16, 0.2f },
212 { DataType::F32, 0.05f },
213 { DataType::QASYMM8, 1.0f }
214 }
215 },
216 {
217 arm_compute::graph::Target::NEON,
218 { { DataType::F16, 0.2f },
219 { DataType::F32, 0.01f },
220 { DataType::QASYMM8, 1.0f }
221 }
222 }
223 };
John Kesapides341b2182019-02-22 10:05:29 +0000224
John Kesapides8d942692019-02-26 14:52:12 +0000225 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000226 }
227
John Kesapides8d942692019-02-26 14:52:12 +0000228 float absolute_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000229 {
230 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
231 {
232 {
233 Target::CL,
234 { { DataType::F16, 0.0f },
235 { DataType::F32, 0.0001f },
236 { DataType::QASYMM8, 1.0f }
237 }
238 },
239 {
240 Target::NEON,
241 { { DataType::F16, 0.3f },
242 { DataType::F32, 0.1f },
243 { DataType::QASYMM8, 1.0f }
244 }
245 }
246 };
247
John Kesapides8d942692019-02-26 14:52:12 +0000248 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000249 }
John Kesapides8d942692019-02-26 14:52:12 +0000250
251 float tolerance_number() override
John Kesapides341b2182019-02-22 10:05:29 +0000252 {
253 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
254 {
255 {
256 Target::CL,
257 { { DataType::F16, 0.07f },
258 { DataType::F32, 0.07f },
259 { DataType::QASYMM8, 0.0f }
260 }
261 },
262 {
263 Target::NEON,
264 { { DataType::F16, 0.07f },
265 { DataType::F32, 0.0f },
266 { DataType::QASYMM8, 0.0f }
267 }
268 }
269 };
270
John Kesapides8d942692019-02-26 14:52:12 +0000271 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000272 }
John Kesapides341b2182019-02-22 10:05:29 +0000273};
274
John Kesapides341b2182019-02-22 10:05:29 +0000275} // namespace
276
John Kesapides8d942692019-02-26 14:52:12 +0000277class GraphFullyConnectedValidateExample final : public GraphValidateExample<FullyConnectedLayer, FullyConnectedOptions, FullyConnectedVerifyAccessor>
John Kesapides341b2182019-02-22 10:05:29 +0000278{
John Kesapides8d942692019-02-26 14:52:12 +0000279 using GraphValidateExample::graph;
280
John Kesapides341b2182019-02-22 10:05:29 +0000281public:
John Kesapides8d942692019-02-26 14:52:12 +0000282 GraphFullyConnectedValidateExample()
283 : GraphValidateExample("Fully_connected Graph example")
John Kesapides341b2182019-02-22 10:05:29 +0000284 {
285 }
286
John Kesapides8d942692019-02-26 14:52:12 +0000287 FullyConnectedLayer GraphFunctionLayer(ExampleParams &params) override
288 {
289 const PixelValue lower = PixelValue(params.input.range_low, params.data_type, params.input.quant_info);
290 const PixelValue upper = PixelValue(params.input.range_high, params.data_type, params.input.quant_info);
291
292 const PixelValue weights_lower = PixelValue(params.weights.range_low, params.data_type, params.weights.quant_info);
293 const PixelValue weights_upper = PixelValue(params.weights.range_high, params.data_type, params.weights.quant_info);
294
295 return FullyConnectedLayer(params.fully_connected.num_outputs,
296 get_random_accessor(weights_lower, weights_upper, 1),
297 get_random_accessor(lower, upper, 2),
298 params.fully_connected.info, params.weights.quant_info, params.output.quant_info);
299 }
John Kesapides341b2182019-02-22 10:05:29 +0000300};
301
302/** Main program for Graph fully_connected test
303 *
304 * @param[in] argc Number of arguments
305 * @param[in] argv Arguments ( Input dimensions [width, batch]
306 * Fully connected [num_outputs,type]
307 * Verification[tolerance_number,absolute_tolerance,relative_tolerance] )
308 *
309 */
310int main(int argc, char **argv)
311{
John Kesapides8d942692019-02-26 14:52:12 +0000312 return arm_compute::utils::run_example<GraphFullyConnectedValidateExample>(argc, argv);
John Kesapides341b2182019-02-22 10:05:29 +0000313}