blob: ccdd4439991a40bcf72eca9631d993de7f08052e [file] [log] [blame]
Giorgio Arena93a690e2017-08-01 16:09:33 +01001/*
Georgios Pinitasf72f9362018-01-12 16:29:45 +00002 * Copyright (c) 2017-2018 ARM Limited.
Giorgio Arena93a690e2017-08-01 16:09:33 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE
25#define ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE
26
27#include "arm_compute/core/TensorShape.h"
28#include "arm_compute/core/Types.h"
Giorgio Arena93a690e2017-08-01 16:09:33 +010029#include "tests/AssetsLibrary.h"
30#include "tests/Globals.h"
31#include "tests/IAccessor.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010032#include "tests/framework/Asserts.h"
33#include "tests/framework/Fixture.h"
Moritz Pflanzera09de0c2017-09-01 20:41:12 +010034#include "tests/validation/Helpers.h"
Georgios Pinitas5a7e7762017-12-01 16:27:29 +000035#include "tests/validation/reference/DepthwiseConvolutionLayer.h"
Giorgio Arena93a690e2017-08-01 16:09:33 +010036
Georgios Pinitasf72f9362018-01-12 16:29:45 +000037#include "utils/Utils.h"
38
Giorgio Arena93a690e2017-08-01 16:09:33 +010039#include <random>
40
41namespace arm_compute
42{
43namespace test
44{
45namespace validation
46{
47template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena04a8f8c2017-11-23 11:45:24 +000048class DepthwiseConvolutionLayerValidationGenericFixture : public framework::Fixture
Giorgio Arena93a690e2017-08-01 16:09:33 +010049{
50public:
Dmitry Savenkod7295b72017-11-20 22:00:08 +070051 using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;
52
53public:
Giorgio Arena93a690e2017-08-01 16:09:33 +010054 template <typename...>
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010055 void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
Giorgio Arena93a690e2017-08-01 16:09:33 +010056 {
Dmitry Savenkod7295b72017-11-20 22:00:08 +070057 _quantization_info = quantization_info;
58 _data_type = data_type;
Pablo Tello941cd702017-12-12 14:35:00 +000059 const TensorShape biases_shape(weights_shape[2]);
60 const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
Dmitry Savenkod7295b72017-11-20 22:00:08 +070061
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010062 if(data_layout == DataLayout::NHWC)
63 {
64 permute(in_shape, PermutationVector(2U, 0U, 1U));
65 permute(weights_shape, PermutationVector(2U, 0U, 1U));
66 permute(out_shape, PermutationVector(2U, 0U, 1U));
67 }
68
69 _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
70 _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
Giorgio Arena93a690e2017-08-01 16:09:33 +010071 }
72
73protected:
74 template <typename U>
75 void fill(U &&tensor, int i)
76 {
77 switch(tensor.data_type())
78 {
Dmitry Savenkod7295b72017-11-20 22:00:08 +070079 case DataType::QASYMM8:
80 {
81 std::uniform_int_distribution<uint8_t> distribution(0, 10);
82 library->fill(tensor, distribution, i);
83 break;
84 }
Giorgio Arena93a690e2017-08-01 16:09:33 +010085 case DataType::F32:
Frank Lei8cdfdb82018-01-02 16:49:33 +080086 case DataType::F16:
Giorgio Arena93a690e2017-08-01 16:09:33 +010087 {
88 std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
89 library->fill(tensor, distribution, i);
90 break;
91 }
Dmitry Savenkod7295b72017-11-20 22:00:08 +070092 case DataType::S32:
93 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +000094 std::uniform_int_distribution<int32_t> distribution(-100, 100);
Dmitry Savenkod7295b72017-11-20 22:00:08 +070095 library->fill(tensor, distribution, i);
96 break;
97 }
Giorgio Arena93a690e2017-08-01 16:09:33 +010098 default:
99 library->fill_tensor_uniform(tensor, i);
100 }
101 }
102
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700103 TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info,
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100104 const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
Giorgio Arena93a690e2017-08-01 16:09:33 +0100105 {
106 // Create tensors
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100107 TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, quantization_info, data_layout);
108 TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, 0, quantization_info, data_layout);
109 TensorType biases = create_tensor<TensorType>(biases_shape, bias_data_type, 1, 0, quantization_info, data_layout);
110 TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, quantization_info, data_layout);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100111
112 // Create Depthwise Convolution configure function
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700113 FunctionType dwc;
114 dwc.configure(&src, &weights, &biases, &dst, pad_stride_info);
115
116 ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
117 ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
118 ARM_COMPUTE_EXPECT(biases.info()->is_resizable(), framework::LogLevel::ERRORS);
119 ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100120
121 // Allocate tensors
122 src.allocator()->allocate();
123 weights.allocator()->allocate();
Georgios Pinitas81a26ad2017-10-23 20:29:30 +0100124 biases.allocator()->allocate();
Giorgio Arena93a690e2017-08-01 16:09:33 +0100125 dst.allocator()->allocate();
126
127 ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
128 ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
Georgios Pinitas81a26ad2017-10-23 20:29:30 +0100129 ARM_COMPUTE_EXPECT(!biases.info()->is_resizable(), framework::LogLevel::ERRORS);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100130 ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
131
132 // Fill tensors
133 fill(AccessorType(src), 0);
134 fill(AccessorType(weights), 1);
Georgios Pinitas81a26ad2017-10-23 20:29:30 +0100135 fill(AccessorType(biases), 2);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100136
137 // Compute function
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700138 dwc.run();
Giorgio Arena93a690e2017-08-01 16:09:33 +0100139
140 return dst;
141 }
142
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700143 SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info,
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100144 const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
Giorgio Arena93a690e2017-08-01 16:09:33 +0100145 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100146 SimpleTensor<T> src{ in_shape, data_type, 1, 0, quantization_info, data_layout };
147 SimpleTensor<T> weights{ weights_shape, data_type, 1, 0, quantization_info, data_layout };
148 SimpleTensor<TBias> biases{ biases_shape, bias_data_type, 1, 0, quantization_info, data_layout };
Giorgio Arena93a690e2017-08-01 16:09:33 +0100149
150 fill(src, 0);
151 fill(weights, 1);
Georgios Pinitas81a26ad2017-10-23 20:29:30 +0100152 fill(biases, 2);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100153
Georgios Pinitas81a26ad2017-10-23 20:29:30 +0100154 return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info);
Giorgio Arena93a690e2017-08-01 16:09:33 +0100155 }
156
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700157 TensorType _target{};
158 SimpleTensor<T> _reference{};
159 DataType _data_type{};
160 QuantizationInfo _quantization_info{};
161};
162
163template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena04a8f8c2017-11-23 11:45:24 +0000164class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700165{
166public:
167 template <typename...>
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100168 void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, DataLayout data_layout)
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700169 {
Pablo Tello941cd702017-12-12 14:35:00 +0000170 DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100171 data_type, QuantizationInfo(), data_layout);
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700172 }
173};
174
175template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
Giorgio Arena04a8f8c2017-11-23 11:45:24 +0000176class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700177{
178public:
179 template <typename...>
Pablo Tello941cd702017-12-12 14:35:00 +0000180 void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info)
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700181 {
Pablo Tello941cd702017-12-12 14:35:00 +0000182 DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100183 data_type, quantization_info, DataLayout::NCHW);
Dmitry Savenkod7295b72017-11-20 22:00:08 +0700184 }
Giorgio Arena93a690e2017-08-01 16:09:33 +0100185};
186} // namespace validation
187} // namespace test
188} // namespace arm_compute
189#endif /* ARM_COMPUTE_TEST_DEPTHWISE_CONVOLUTION_FIXTURE */