blob: 645fa8b12466f8c3e1cdd1564687d234a0afd200 [file] [log] [blame]
John Kesapides341b2182019-02-22 10:05:29 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph.h"
25
26#include "support/ToolchainSupport.h"
27
28#include "tests/NEON/Accessor.h"
29#include "tests/validation/Validation.h"
30#include "tests/validation/reference/FullyConnectedLayer.h"
31#include "tests/validation/reference/Permute.h"
32
33#include "utils/CommonGraphOptions.h"
34#include "utils/GraphUtils.h"
35#include "utils/Utils.h"
36
37#include "ValidateExample.h"
John Kesapides8d942692019-02-26 14:52:12 +000038#include "graph_validate_utils.h"
John Kesapides341b2182019-02-22 10:05:29 +000039
40#include <utility>
41
42using namespace arm_compute::utils;
43using namespace arm_compute::graph::frontend;
44using namespace arm_compute::graph_utils;
45using namespace arm_compute::graph;
46using namespace arm_compute;
47using namespace arm_compute::test;
48using namespace arm_compute::test::validation;
John Kesapides8d942692019-02-26 14:52:12 +000049
John Kesapides341b2182019-02-22 10:05:29 +000050namespace
51{
John Kesapides8d942692019-02-26 14:52:12 +000052/** Fully connected command line options used to configure the graph examples
John Kesapides341b2182019-02-22 10:05:29 +000053 *
54 * (Similar to common options)
55 * The options in this object get populated when "parse()" is called on the parser used to construct it.
56 * The expected workflow is:
57 *
58 * CommandLineParser parser;
59 * CommonOptions options( parser );
60 * parser.parse(argc, argv);
61 */
John Kesapides8d942692019-02-26 14:52:12 +000062class FullyConnectedOptions final : public CommonGraphValidateOptions
John Kesapides341b2182019-02-22 10:05:29 +000063{
64public:
65 explicit FullyConnectedOptions(CommandLineParser &parser) noexcept
John Kesapides8d942692019-02-26 14:52:12 +000066 : CommonGraphValidateOptions(parser),
67 width(parser.add_option<SimpleOption<int>>("width", 3)),
John Kesapides341b2182019-02-22 10:05:29 +000068 batch(parser.add_option<SimpleOption<int>>("batch", 1)),
John Kesapides341b2182019-02-22 10:05:29 +000069 input_scale(parser.add_option<SimpleOption<float>>("input_scale", 1.0f)),
70 input_offset(parser.add_option<SimpleOption<int>>("input_offset", 0)),
71 weights_scale(parser.add_option<SimpleOption<float>>("weights_scale", 1.0f)),
72 weights_offset(parser.add_option<SimpleOption<int>>("weights_offset", 0)),
73 output_scale(parser.add_option<SimpleOption<float>>("output_scale", 1.0f)),
74 output_offset(parser.add_option<SimpleOption<int>>("output_offset", 0)),
75 num_outputs(parser.add_option<SimpleOption<int>>("num_outputs", 1)),
76 input_range_low(parser.add_option<SimpleOption<uint64_t>>("input_range_low")),
77 input_range_high(parser.add_option<SimpleOption<uint64_t>>("input_range_high")),
78 weights_range_low(parser.add_option<SimpleOption<uint64_t>>("weights_range_low")),
79 weights_range_high(parser.add_option<SimpleOption<uint64_t>>("weights_range_high"))
80 {
John Kesapides341b2182019-02-22 10:05:29 +000081 width->set_help("Set Input dimension width");
82 batch->set_help("Set Input dimension batch");
John Kesapides341b2182019-02-22 10:05:29 +000083 input_scale->set_help("Quantization scale from QASYMM8");
84 input_offset->set_help("Quantization offset from QASYMM8");
85 weights_scale->set_help("Quantization scale from QASYMM8");
86 weights_offset->set_help("Quantization offset from QASYMM8");
87 output_scale->set_help("Quantization scale from QASYMM8");
88 output_offset->set_help("Quantization offset from QASYMM8");
89 num_outputs->set_help("Number of outputs.");
90 input_range_low->set_help("Lower bound for input randomization range");
91 input_range_high->set_help("Lower bound for input randomization range");
92 weights_range_low->set_help("Lower bound for input randomization range");
93 weights_range_high->set_help("Lower bound for input randomization range");
94 }
95
John Kesapides8d942692019-02-26 14:52:12 +000096 /** Fill out the supplied parameters with user supplied parameters
97 *
98 * @param[out] os Output stream.
99 * @param[in] common_params Example parameters to output
100 *
101 * @return None.
102 */
103 void consume_parameters(ExampleParams &common_params)
104 {
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100105 common_params.input.width = width->value();
106 common_params.input.batch = batch->value();
107 common_params.input.quant_info = QuantizationInfo(input_scale->value(), input_offset->value());
108 common_params.input.range_low = input_range_low->value();
109 common_params.input.range_high = input_range_high->value();
John Kesapides8d942692019-02-26 14:52:12 +0000110
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100111 common_params.weights.quant_info = QuantizationInfo(weights_scale->value(), weights_offset->value());
112 common_params.weights.range_low = weights_range_low->value();
113 common_params.weights.range_high = weights_range_high->value();
John Kesapides8d942692019-02-26 14:52:12 +0000114
Georgios Pinitas4c5469b2019-05-21 13:32:43 +0100115 common_params.output.quant_info = QuantizationInfo(output_scale->value(), output_offset->value());
John Kesapides8d942692019-02-26 14:52:12 +0000116
117 common_params.data_type = data_type->value();
118 common_params.fully_connected.num_outputs = num_outputs->value();
119 }
120
121 void print_parameters(::std::ostream &os, const ExampleParams &common_params) override
122 {
123 os << "Threads : " << common_params.common_params.threads << std::endl;
124 os << "Target : " << common_params.common_params.target << std::endl;
125 os << "Data type : " << common_params.data_type << std::endl;
126 os << "Input dimensions(X,Y, Channels, Batch) : (" << common_params.input.width << "," << common_params.input.height << "," << common_params.input.fm << "," << common_params.input.batch << ")"
127 << std::endl;
128 os << "Number of outputs : " << common_params.fully_connected.num_outputs << std::endl;
129 }
130
John Kesapides341b2182019-02-22 10:05:29 +0000131 /** Prevent instances of this class from being copied (As this class contains pointers) */
132 FullyConnectedOptions(const FullyConnectedOptions &) = delete;
133 /** Prevent instances of this class from being copied (As this class contains pointers) */
134 FullyConnectedOptions &operator=(const FullyConnectedOptions &) = delete;
135 /** Allow instances of this class to be moved */
136 FullyConnectedOptions(FullyConnectedOptions &&) noexcept(true) = default;
137 /** Allow instances of this class to be moved */
138 FullyConnectedOptions &operator=(FullyConnectedOptions &&) noexcept(true) = default;
139 /** Default destructor */
John Kesapides8d942692019-02-26 14:52:12 +0000140 ~FullyConnectedOptions() override = default;
John Kesapides341b2182019-02-22 10:05:29 +0000141
Michalis Spyrou299fdd32019-05-01 13:03:59 +0100142private:
John Kesapides8d942692019-02-26 14:52:12 +0000143 SimpleOption<int> *width; /**< Input width */
144 SimpleOption<int> *batch; /**< Input batch */
145 SimpleOption<float> *input_scale; /**< Input Quantization scale from QASSYMM8 */
146 SimpleOption<int> *input_offset; /**< Input Quantization offset from QASSYMM8 */
147 SimpleOption<float> *weights_scale; /**< Weights Quantization scale from QASSYMM8 */
148 SimpleOption<int> *weights_offset; /**< Weights Quantization offset from QASSYMM8 */
149 SimpleOption<float> *output_scale; /**< Output Quantization scale from QASSYMM8 */
150 SimpleOption<int> *output_offset; /**< Output Quantization offset from QASSYMM8 */
151 SimpleOption<int> *num_outputs; /**< Number of outputs. */
152 SimpleOption<uint64_t> *input_range_low; /**< Lower bound for input randomization range */
153 SimpleOption<uint64_t> *input_range_high; /**< Upper bound for input randomization range */
154 SimpleOption<uint64_t> *weights_range_low; /**< Lower bound for weights randomization range */
155 SimpleOption<uint64_t> *weights_range_high; /**< Upper bound for weights randomization range */
John Kesapides341b2182019-02-22 10:05:29 +0000156};
157
John Kesapides8d942692019-02-26 14:52:12 +0000158/** Fully Connected Layer Graph example validation accessor class */
John Kesapides341b2182019-02-22 10:05:29 +0000159template <typename D>
John Kesapides8d942692019-02-26 14:52:12 +0000160class FullyConnectedVerifyAccessor final : public VerifyAccessor<D>
John Kesapides341b2182019-02-22 10:05:29 +0000161{
John Kesapides8d942692019-02-26 14:52:12 +0000162 using BaseClassType = VerifyAccessor<D>;
163 using BaseClassType::BaseClassType;
164 using BaseClassType::_params;
John Kesapides341b2182019-02-22 10:05:29 +0000165 using TBias = typename std::conditional<std::is_same<typename std::decay<D>::type, uint8_t>::value, int32_t, D>::type;
166
John Kesapides8d942692019-02-26 14:52:12 +0000167 // Inherited methods overriden:
168 void create_tensors(arm_compute::test::SimpleTensor<D> &src,
169 arm_compute::test::SimpleTensor<D> &weights,
170 arm_compute::test::SimpleTensor<TBias> &bias,
171 ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000172 {
John Kesapides341b2182019-02-22 10:05:29 +0000173 // Calculate Tensor shapes for verification
174 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
John Kesapides8d942692019-02-26 14:52:12 +0000175 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000176 const TensorDescriptor weights_descriptor = FullyConnectedLayerNode::compute_weights_descriptor(input_descriptor,
177 _params.fully_connected.num_outputs,
178 _params.fully_connected.info,
179 _params.weights.quant_info);
180 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
181
182 //Create Input tensors
John Kesapides8d942692019-02-26 14:52:12 +0000183 src = SimpleTensor<D> { input_descriptor.shape, _params.data_type, 1, input_descriptor.quant_info };
184 weights = SimpleTensor<D> { weights_descriptor.shape, _params.data_type, 1, weights_descriptor.quant_info };
185 bias = SimpleTensor<TBias> { TensorShape(tensor.info()->tensor_shape().x()), _params.data_type, 1, _params.input.quant_info };
John Kesapides341b2182019-02-22 10:05:29 +0000186 }
187
John Kesapides8d942692019-02-26 14:52:12 +0000188 TensorShape output_shape(ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000189 {
John Kesapides8d942692019-02-26 14:52:12 +0000190 ARM_COMPUTE_UNUSED(tensor);
John Kesapides341b2182019-02-22 10:05:29 +0000191
John Kesapides8d942692019-02-26 14:52:12 +0000192 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
193 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
194 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000195
John Kesapides8d942692019-02-26 14:52:12 +0000196 return output_desciptor.shape;
John Kesapides341b2182019-02-22 10:05:29 +0000197 }
John Kesapides8d942692019-02-26 14:52:12 +0000198
199 arm_compute::test::SimpleTensor<D> reference(arm_compute::test::SimpleTensor<D> &src,
200 arm_compute::test::SimpleTensor<D> &weights,
201 arm_compute::test::SimpleTensor<TBias> &bias,
202 const arm_compute::TensorShape &output_shape) override
203 {
204 return reference::fully_connected_layer<D>(src, weights, bias, output_shape, _params.output.quant_info);
205 }
206
207 float relative_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000208 {
209 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
210 {
211 {
212 arm_compute::graph::Target::CL,
213 { { DataType::F16, 0.2f },
214 { DataType::F32, 0.05f },
215 { DataType::QASYMM8, 1.0f }
216 }
217 },
218 {
219 arm_compute::graph::Target::NEON,
220 { { DataType::F16, 0.2f },
221 { DataType::F32, 0.01f },
222 { DataType::QASYMM8, 1.0f }
223 }
224 }
225 };
John Kesapides341b2182019-02-22 10:05:29 +0000226
John Kesapides8d942692019-02-26 14:52:12 +0000227 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000228 }
229
John Kesapides8d942692019-02-26 14:52:12 +0000230 float absolute_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000231 {
232 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
233 {
234 {
235 Target::CL,
236 { { DataType::F16, 0.0f },
237 { DataType::F32, 0.0001f },
238 { DataType::QASYMM8, 1.0f }
239 }
240 },
241 {
242 Target::NEON,
243 { { DataType::F16, 0.3f },
244 { DataType::F32, 0.1f },
245 { DataType::QASYMM8, 1.0f }
246 }
247 }
248 };
249
John Kesapides8d942692019-02-26 14:52:12 +0000250 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000251 }
John Kesapides8d942692019-02-26 14:52:12 +0000252
253 float tolerance_number() override
John Kesapides341b2182019-02-22 10:05:29 +0000254 {
255 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
256 {
257 {
258 Target::CL,
259 { { DataType::F16, 0.07f },
260 { DataType::F32, 0.07f },
261 { DataType::QASYMM8, 0.0f }
262 }
263 },
264 {
265 Target::NEON,
266 { { DataType::F16, 0.07f },
267 { DataType::F32, 0.0f },
268 { DataType::QASYMM8, 0.0f }
269 }
270 }
271 };
272
John Kesapides8d942692019-02-26 14:52:12 +0000273 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000274 }
John Kesapides341b2182019-02-22 10:05:29 +0000275};
276
John Kesapides341b2182019-02-22 10:05:29 +0000277} // namespace
278
John Kesapides8d942692019-02-26 14:52:12 +0000279class GraphFullyConnectedValidateExample final : public GraphValidateExample<FullyConnectedLayer, FullyConnectedOptions, FullyConnectedVerifyAccessor>
John Kesapides341b2182019-02-22 10:05:29 +0000280{
John Kesapides8d942692019-02-26 14:52:12 +0000281 using GraphValidateExample::graph;
282
John Kesapides341b2182019-02-22 10:05:29 +0000283public:
John Kesapides8d942692019-02-26 14:52:12 +0000284 GraphFullyConnectedValidateExample()
285 : GraphValidateExample("Fully_connected Graph example")
John Kesapides341b2182019-02-22 10:05:29 +0000286 {
287 }
288
John Kesapides8d942692019-02-26 14:52:12 +0000289 FullyConnectedLayer GraphFunctionLayer(ExampleParams &params) override
290 {
291 const PixelValue lower = PixelValue(params.input.range_low, params.data_type, params.input.quant_info);
292 const PixelValue upper = PixelValue(params.input.range_high, params.data_type, params.input.quant_info);
293
294 const PixelValue weights_lower = PixelValue(params.weights.range_low, params.data_type, params.weights.quant_info);
295 const PixelValue weights_upper = PixelValue(params.weights.range_high, params.data_type, params.weights.quant_info);
296
297 return FullyConnectedLayer(params.fully_connected.num_outputs,
298 get_random_accessor(weights_lower, weights_upper, 1),
299 get_random_accessor(lower, upper, 2),
300 params.fully_connected.info, params.weights.quant_info, params.output.quant_info);
301 }
John Kesapides341b2182019-02-22 10:05:29 +0000302};
303
304/** Main program for Graph fully_connected test
305 *
306 * @param[in] argc Number of arguments
307 * @param[in] argv Arguments ( Input dimensions [width, batch]
308 * Fully connected [num_outputs,type]
309 * Verification[tolerance_number,absolute_tolerance,relative_tolerance] )
310 *
311 */
312int main(int argc, char **argv)
313{
John Kesapides8d942692019-02-26 14:52:12 +0000314 return arm_compute::utils::run_example<GraphFullyConnectedValidateExample>(argc, argv);
John Kesapides341b2182019-02-22 10:05:29 +0000315}