blob: dfa15edd6d58512ac9e6be2df9e51de5c2524a59 [file] [log] [blame]
John Kesapides341b2182019-02-22 10:05:29 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/graph.h"
25
26#include "support/ToolchainSupport.h"
27
28#include "tests/NEON/Accessor.h"
29#include "tests/validation/Validation.h"
30#include "tests/validation/reference/FullyConnectedLayer.h"
31#include "tests/validation/reference/Permute.h"
32
33#include "utils/CommonGraphOptions.h"
34#include "utils/GraphUtils.h"
35#include "utils/Utils.h"
36
37#include "ValidateExample.h"
John Kesapides8d942692019-02-26 14:52:12 +000038#include "graph_validate_utils.h"
John Kesapides341b2182019-02-22 10:05:29 +000039
40#include <utility>
41
42using namespace arm_compute::utils;
43using namespace arm_compute::graph::frontend;
44using namespace arm_compute::graph_utils;
45using namespace arm_compute::graph;
46using namespace arm_compute;
47using namespace arm_compute::test;
48using namespace arm_compute::test::validation;
John Kesapides8d942692019-02-26 14:52:12 +000049
John Kesapides341b2182019-02-22 10:05:29 +000050namespace
51{
John Kesapides8d942692019-02-26 14:52:12 +000052/** Fully connected command line options used to configure the graph examples
John Kesapides341b2182019-02-22 10:05:29 +000053 *
54 * (Similar to common options)
55 * The options in this object get populated when "parse()" is called on the parser used to construct it.
56 * The expected workflow is:
57 *
58 * CommandLineParser parser;
59 * CommonOptions options( parser );
60 * parser.parse(argc, argv);
61 */
John Kesapides8d942692019-02-26 14:52:12 +000062class FullyConnectedOptions final : public CommonGraphValidateOptions
John Kesapides341b2182019-02-22 10:05:29 +000063{
64public:
65 explicit FullyConnectedOptions(CommandLineParser &parser) noexcept
John Kesapides8d942692019-02-26 14:52:12 +000066 : CommonGraphValidateOptions(parser),
67 width(parser.add_option<SimpleOption<int>>("width", 3)),
John Kesapides341b2182019-02-22 10:05:29 +000068 batch(parser.add_option<SimpleOption<int>>("batch", 1)),
John Kesapides341b2182019-02-22 10:05:29 +000069 input_scale(parser.add_option<SimpleOption<float>>("input_scale", 1.0f)),
70 input_offset(parser.add_option<SimpleOption<int>>("input_offset", 0)),
71 weights_scale(parser.add_option<SimpleOption<float>>("weights_scale", 1.0f)),
72 weights_offset(parser.add_option<SimpleOption<int>>("weights_offset", 0)),
73 output_scale(parser.add_option<SimpleOption<float>>("output_scale", 1.0f)),
74 output_offset(parser.add_option<SimpleOption<int>>("output_offset", 0)),
75 num_outputs(parser.add_option<SimpleOption<int>>("num_outputs", 1)),
76 input_range_low(parser.add_option<SimpleOption<uint64_t>>("input_range_low")),
77 input_range_high(parser.add_option<SimpleOption<uint64_t>>("input_range_high")),
78 weights_range_low(parser.add_option<SimpleOption<uint64_t>>("weights_range_low")),
79 weights_range_high(parser.add_option<SimpleOption<uint64_t>>("weights_range_high"))
80 {
John Kesapides341b2182019-02-22 10:05:29 +000081 width->set_help("Set Input dimension width");
82 batch->set_help("Set Input dimension batch");
John Kesapides341b2182019-02-22 10:05:29 +000083 input_scale->set_help("Quantization scale from QASYMM8");
84 input_offset->set_help("Quantization offset from QASYMM8");
85 weights_scale->set_help("Quantization scale from QASYMM8");
86 weights_offset->set_help("Quantization offset from QASYMM8");
87 output_scale->set_help("Quantization scale from QASYMM8");
88 output_offset->set_help("Quantization offset from QASYMM8");
89 num_outputs->set_help("Number of outputs.");
90 input_range_low->set_help("Lower bound for input randomization range");
91 input_range_high->set_help("Lower bound for input randomization range");
92 weights_range_low->set_help("Lower bound for input randomization range");
93 weights_range_high->set_help("Lower bound for input randomization range");
94 }
95
John Kesapides8d942692019-02-26 14:52:12 +000096 /** Fill out the supplied parameters with user supplied parameters
97 *
98 * @param[out] os Output stream.
99 * @param[in] common_params Example parameters to output
100 *
101 * @return None.
102 */
103 void consume_parameters(ExampleParams &common_params)
104 {
105 common_params.input.width = width->value();
106 common_params.input.batch = batch->value();
107 common_params.input.quant_info.scale = input_scale->value();
108 common_params.input.quant_info.offset = input_offset->value();
109 common_params.input.range_low = input_range_low->value();
110 common_params.input.range_high = input_range_high->value();
111
112 common_params.weights.quant_info.scale = weights_scale->value();
113 common_params.weights.quant_info.offset = weights_offset->value();
114 common_params.weights.range_low = weights_range_low->value();
115 common_params.weights.range_high = weights_range_high->value();
116
117 common_params.output.quant_info.scale = output_scale->value();
118 common_params.output.quant_info.offset = output_offset->value();
119
120 common_params.data_type = data_type->value();
121 common_params.fully_connected.num_outputs = num_outputs->value();
122 }
123
124 void print_parameters(::std::ostream &os, const ExampleParams &common_params) override
125 {
126 os << "Threads : " << common_params.common_params.threads << std::endl;
127 os << "Target : " << common_params.common_params.target << std::endl;
128 os << "Data type : " << common_params.data_type << std::endl;
129 os << "Input dimensions(X,Y, Channels, Batch) : (" << common_params.input.width << "," << common_params.input.height << "," << common_params.input.fm << "," << common_params.input.batch << ")"
130 << std::endl;
131 os << "Number of outputs : " << common_params.fully_connected.num_outputs << std::endl;
132 }
133
John Kesapides341b2182019-02-22 10:05:29 +0000134 /** Prevent instances of this class from being copied (As this class contains pointers) */
135 FullyConnectedOptions(const FullyConnectedOptions &) = delete;
136 /** Prevent instances of this class from being copied (As this class contains pointers) */
137 FullyConnectedOptions &operator=(const FullyConnectedOptions &) = delete;
138 /** Allow instances of this class to be moved */
139 FullyConnectedOptions(FullyConnectedOptions &&) noexcept(true) = default;
140 /** Allow instances of this class to be moved */
141 FullyConnectedOptions &operator=(FullyConnectedOptions &&) noexcept(true) = default;
142 /** Default destructor */
John Kesapides8d942692019-02-26 14:52:12 +0000143 ~FullyConnectedOptions() override = default;
John Kesapides341b2182019-02-22 10:05:29 +0000144
Michalis Spyrou299fdd32019-05-01 13:03:59 +0100145private:
John Kesapides8d942692019-02-26 14:52:12 +0000146 SimpleOption<int> *width; /**< Input width */
147 SimpleOption<int> *batch; /**< Input batch */
148 SimpleOption<float> *input_scale; /**< Input Quantization scale from QASSYMM8 */
149 SimpleOption<int> *input_offset; /**< Input Quantization offset from QASSYMM8 */
150 SimpleOption<float> *weights_scale; /**< Weights Quantization scale from QASSYMM8 */
151 SimpleOption<int> *weights_offset; /**< Weights Quantization offset from QASSYMM8 */
152 SimpleOption<float> *output_scale; /**< Output Quantization scale from QASSYMM8 */
153 SimpleOption<int> *output_offset; /**< Output Quantization offset from QASSYMM8 */
154 SimpleOption<int> *num_outputs; /**< Number of outputs. */
155 SimpleOption<uint64_t> *input_range_low; /**< Lower bound for input randomization range */
156 SimpleOption<uint64_t> *input_range_high; /**< Upper bound for input randomization range */
157 SimpleOption<uint64_t> *weights_range_low; /**< Lower bound for weights randomization range */
158 SimpleOption<uint64_t> *weights_range_high; /**< Upper bound for weights randomization range */
John Kesapides341b2182019-02-22 10:05:29 +0000159};
160
John Kesapides8d942692019-02-26 14:52:12 +0000161/** Fully Connected Layer Graph example validation accessor class */
John Kesapides341b2182019-02-22 10:05:29 +0000162template <typename D>
John Kesapides8d942692019-02-26 14:52:12 +0000163class FullyConnectedVerifyAccessor final : public VerifyAccessor<D>
John Kesapides341b2182019-02-22 10:05:29 +0000164{
John Kesapides8d942692019-02-26 14:52:12 +0000165 using BaseClassType = VerifyAccessor<D>;
166 using BaseClassType::BaseClassType;
167 using BaseClassType::_params;
John Kesapides341b2182019-02-22 10:05:29 +0000168 using TBias = typename std::conditional<std::is_same<typename std::decay<D>::type, uint8_t>::value, int32_t, D>::type;
169
John Kesapides8d942692019-02-26 14:52:12 +0000170 // Inherited methods overriden:
171 void create_tensors(arm_compute::test::SimpleTensor<D> &src,
172 arm_compute::test::SimpleTensor<D> &weights,
173 arm_compute::test::SimpleTensor<TBias> &bias,
174 ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000175 {
John Kesapides341b2182019-02-22 10:05:29 +0000176 // Calculate Tensor shapes for verification
177 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
John Kesapides8d942692019-02-26 14:52:12 +0000178 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000179 const TensorDescriptor weights_descriptor = FullyConnectedLayerNode::compute_weights_descriptor(input_descriptor,
180 _params.fully_connected.num_outputs,
181 _params.fully_connected.info,
182 _params.weights.quant_info);
183 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
184
185 //Create Input tensors
John Kesapides8d942692019-02-26 14:52:12 +0000186 src = SimpleTensor<D> { input_descriptor.shape, _params.data_type, 1, input_descriptor.quant_info };
187 weights = SimpleTensor<D> { weights_descriptor.shape, _params.data_type, 1, weights_descriptor.quant_info };
188 bias = SimpleTensor<TBias> { TensorShape(tensor.info()->tensor_shape().x()), _params.data_type, 1, _params.input.quant_info };
John Kesapides341b2182019-02-22 10:05:29 +0000189 }
190
John Kesapides8d942692019-02-26 14:52:12 +0000191 TensorShape output_shape(ITensor &tensor) override
John Kesapides341b2182019-02-22 10:05:29 +0000192 {
John Kesapides8d942692019-02-26 14:52:12 +0000193 ARM_COMPUTE_UNUSED(tensor);
John Kesapides341b2182019-02-22 10:05:29 +0000194
John Kesapides8d942692019-02-26 14:52:12 +0000195 const TensorShape input_shape = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
196 const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
197 const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
John Kesapides341b2182019-02-22 10:05:29 +0000198
John Kesapides8d942692019-02-26 14:52:12 +0000199 return output_desciptor.shape;
John Kesapides341b2182019-02-22 10:05:29 +0000200 }
John Kesapides8d942692019-02-26 14:52:12 +0000201
202 arm_compute::test::SimpleTensor<D> reference(arm_compute::test::SimpleTensor<D> &src,
203 arm_compute::test::SimpleTensor<D> &weights,
204 arm_compute::test::SimpleTensor<TBias> &bias,
205 const arm_compute::TensorShape &output_shape) override
206 {
207 return reference::fully_connected_layer<D>(src, weights, bias, output_shape, _params.output.quant_info);
208 }
209
210 float relative_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000211 {
212 const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
213 {
214 {
215 arm_compute::graph::Target::CL,
216 { { DataType::F16, 0.2f },
217 { DataType::F32, 0.05f },
218 { DataType::QASYMM8, 1.0f }
219 }
220 },
221 {
222 arm_compute::graph::Target::NEON,
223 { { DataType::F16, 0.2f },
224 { DataType::F32, 0.01f },
225 { DataType::QASYMM8, 1.0f }
226 }
227 }
228 };
John Kesapides341b2182019-02-22 10:05:29 +0000229
John Kesapides8d942692019-02-26 14:52:12 +0000230 return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000231 }
232
John Kesapides8d942692019-02-26 14:52:12 +0000233 float absolute_tolerance() override
John Kesapides341b2182019-02-22 10:05:29 +0000234 {
235 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
236 {
237 {
238 Target::CL,
239 { { DataType::F16, 0.0f },
240 { DataType::F32, 0.0001f },
241 { DataType::QASYMM8, 1.0f }
242 }
243 },
244 {
245 Target::NEON,
246 { { DataType::F16, 0.3f },
247 { DataType::F32, 0.1f },
248 { DataType::QASYMM8, 1.0f }
249 }
250 }
251 };
252
John Kesapides8d942692019-02-26 14:52:12 +0000253 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000254 }
John Kesapides8d942692019-02-26 14:52:12 +0000255
256 float tolerance_number() override
John Kesapides341b2182019-02-22 10:05:29 +0000257 {
258 const std::map<Target, const std::map<DataType, float>> absolute_tolerance
259 {
260 {
261 Target::CL,
262 { { DataType::F16, 0.07f },
263 { DataType::F32, 0.07f },
264 { DataType::QASYMM8, 0.0f }
265 }
266 },
267 {
268 Target::NEON,
269 { { DataType::F16, 0.07f },
270 { DataType::F32, 0.0f },
271 { DataType::QASYMM8, 0.0f }
272 }
273 }
274 };
275
John Kesapides8d942692019-02-26 14:52:12 +0000276 return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
John Kesapides341b2182019-02-22 10:05:29 +0000277 }
John Kesapides341b2182019-02-22 10:05:29 +0000278};
279
John Kesapides341b2182019-02-22 10:05:29 +0000280} // namespace
281
John Kesapides8d942692019-02-26 14:52:12 +0000282class GraphFullyConnectedValidateExample final : public GraphValidateExample<FullyConnectedLayer, FullyConnectedOptions, FullyConnectedVerifyAccessor>
John Kesapides341b2182019-02-22 10:05:29 +0000283{
John Kesapides8d942692019-02-26 14:52:12 +0000284 using GraphValidateExample::graph;
285
John Kesapides341b2182019-02-22 10:05:29 +0000286public:
John Kesapides8d942692019-02-26 14:52:12 +0000287 GraphFullyConnectedValidateExample()
288 : GraphValidateExample("Fully_connected Graph example")
John Kesapides341b2182019-02-22 10:05:29 +0000289 {
290 }
291
John Kesapides8d942692019-02-26 14:52:12 +0000292 FullyConnectedLayer GraphFunctionLayer(ExampleParams &params) override
293 {
294 const PixelValue lower = PixelValue(params.input.range_low, params.data_type, params.input.quant_info);
295 const PixelValue upper = PixelValue(params.input.range_high, params.data_type, params.input.quant_info);
296
297 const PixelValue weights_lower = PixelValue(params.weights.range_low, params.data_type, params.weights.quant_info);
298 const PixelValue weights_upper = PixelValue(params.weights.range_high, params.data_type, params.weights.quant_info);
299
300 return FullyConnectedLayer(params.fully_connected.num_outputs,
301 get_random_accessor(weights_lower, weights_upper, 1),
302 get_random_accessor(lower, upper, 2),
303 params.fully_connected.info, params.weights.quant_info, params.output.quant_info);
304 }
John Kesapides341b2182019-02-22 10:05:29 +0000305};
306
307/** Main program for Graph fully_connected test
308 *
309 * @param[in] argc Number of arguments
310 * @param[in] argv Arguments ( Input dimensions [width, batch]
311 * Fully connected [num_outputs,type]
312 * Verification[tolerance_number,absolute_tolerance,relative_tolerance] )
313 *
314 */
315int main(int argc, char **argv)
316{
John Kesapides8d942692019-02-26 14:52:12 +0000317 return arm_compute::utils::run_example<GraphFullyConnectedValidateExample>(argc, argv);
John Kesapides341b2182019-02-22 10:05:29 +0000318}