blob: 2db6abc9d6802a1a9684641bafd5d1e2a0927dde [file] [log] [blame]
Giorgio Arena16def8d2021-10-07 11:03:12 +01001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/utils/misc/ShapeCalculator.h"
25#include "tests/framework/Fixture.h"
26#include "tests/validation/reference/ActivationLayer.h"
27#include "tests/validation/reference/Conv3D.h"
28
29#include <random>
30
31namespace arm_compute
32{
33namespace test
34{
35namespace validation
36{
37using namespace arm_compute::misc::shape_calculator;
38
39template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
40class DirectConvolution3DValidationGenericFixture : public framework::Fixture
41{
42public:
43 template <typename...>
44 void setup(TensorShape input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
45 unsigned int num_kernels, bool has_bias, ActivationLayerInfo act_info, DataType data_type, DataLayout data_layout)
46 {
47 ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NDHWC);
48
49 TensorShape weights_shape(num_kernels, input_shape[0], kernel_width, kernel_height, kernel_depth);
50 const TensorShape bias_shape(num_kernels);
51 const Conv3dInfo conv3d_info(Size3D(stride_x, stride_y, stride_z), Padding3D(pad_x, pad_y, pad_z), act_info, Size3D(), DimensionRoundingType::FLOOR, false);
52 const TensorShape output_shape = compute_conv3d_shape(input_shape, weights_shape, conv3d_info);
53
54 _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, data_layout);
55 _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type);
56 }
57
58protected:
59 template <typename U>
60 void fill(U &&tensor, int i)
61 {
62 switch(tensor.data_type())
63 {
64 case DataType::F16:
65 {
66 arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
67 library->fill(tensor, distribution, i);
68 break;
69 }
70 case DataType::F32:
71 {
72 std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
73 library->fill(tensor, distribution, i);
74 break;
75 }
76 default:
77 library->fill_tensor_uniform(tensor, i);
78 }
79 }
80
81 TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const Conv3dInfo &conv3d_info,
82 bool has_bias, const DataType &data_type, const DataLayout &data_layout)
83 {
84 // Create tensors
85 TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, QuantizationInfo(), data_layout);
86 TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, QuantizationInfo(), data_layout);
87 TensorType bias = has_bias ? create_tensor<TensorType>(bias_shape, data_type, 1, QuantizationInfo()) : TensorType();
88 TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, QuantizationInfo(), data_layout);
89
90 add_padding_x({ &src, &dst, &weights }, data_layout);
91
92 if(has_bias)
93 {
94 add_padding_x({ &bias }, data_layout);
95 }
96
97 // Create and configure function
98 FunctionType conv{};
99 conv.configure(&src, &weights, has_bias ? &bias : nullptr, &dst, conv3d_info);
100
101 ARM_COMPUTE_ASSERT(src.info()->is_resizable());
102 ARM_COMPUTE_ASSERT(weights.info()->is_resizable());
103 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
104
105 // Allocate tensors
106 src.allocator()->allocate();
107 weights.allocator()->allocate();
108 dst.allocator()->allocate();
109
110 ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
111 ARM_COMPUTE_ASSERT(!weights.info()->is_resizable());
112 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
113
114 // Fill tensors
115 fill(AccessorType(src), 0);
116 fill(AccessorType(weights), 1);
117
118 if(has_bias)
119 {
120 ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
121 bias.allocator()->allocate();
122 ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
123 fill(AccessorType(bias), 2);
124 }
125
126 // Compute Direct Convolution 3D function
127 conv.run();
128
129 return dst;
130 }
131
132 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const Conv3dInfo &conv3d_info,
133 bool has_bias, const DataType &data_type)
134 {
135 // Create reference
136 SimpleTensor<T> src{ input_shape, data_type };
137 SimpleTensor<T> weights{ weights_shape, data_type };
138 SimpleTensor<T> bias{ bias_shape, data_type };
139 SimpleTensor<T> dst{ output_shape, data_type };
140
141 // Fill reference
142 fill(src, 0);
143 fill(weights, 1);
144
145 if(has_bias)
146 {
147 fill(bias, 2);
148 }
149
150 return reference::activation_layer(reference::conv3d<T>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
151 }
152
153 TensorType _target{};
154 SimpleTensor<T> _reference{};
155};
156
157template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
158class DirectConvolution3DValidationFixture : public DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
159{
160public:
161 template <typename...>
162 void setup(TensorShape input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
163 unsigned int num_kernels, bool has_bias, ActivationLayerInfo act_info, DataType data_type, DataLayout data_layout)
164 {
165 DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, stride_z, pad_x, pad_y, pad_z, kernel_width, kernel_height,
166 kernel_depth, num_kernels, has_bias, act_info, data_type, data_layout);
167 }
168};
169} // namespace validation
170} // namespace test
171} // namespace arm_compute