blob: e80ad2f54ff90f3c79ae647c3d272a2577e09359 [file] [log] [blame]
Giorgio Arena16def8d2021-10-07 11:03:12 +01001/*
Matthew Bentham945b8da2023-07-12 11:54:59 +00002 * Copyright (c) 2021, 2023 Arm Limited.
Giorgio Arena16def8d2021-10-07 11:03:12 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Jakub Sujak8f4b3df2023-10-30 16:04:51 +000024
25#ifndef ACL_TESTS_VALIDATION_FIXTURES_DIRECTCONVOLUTION3DFIXTURE_H
26#define ACL_TESTS_VALIDATION_FIXTURES_DIRECTCONVOLUTION3DFIXTURE_H
27
Giorgio Arena16def8d2021-10-07 11:03:12 +010028#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Jakub Sujak8f4b3df2023-10-30 16:04:51 +000029#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
Giorgio Arena16def8d2021-10-07 11:03:12 +010030#include "tests/framework/Fixture.h"
31#include "tests/validation/reference/ActivationLayer.h"
32#include "tests/validation/reference/Conv3D.h"
33
34#include <random>
35
36namespace arm_compute
37{
38namespace test
39{
40namespace validation
41{
42using namespace arm_compute::misc::shape_calculator;
43
44template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45class DirectConvolution3DValidationGenericFixture : public framework::Fixture
46{
47public:
Giorgio Arena51847d52021-10-19 15:45:57 +010048 using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
49
Giorgio Arena945ae9e2021-10-13 11:13:04 +010050 void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
Giorgio Arena51847d52021-10-19 15:45:57 +010051 unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
52 const QuantizationInfo &src_qinfo = QuantizationInfo(), const QuantizationInfo &weights_qinfo = QuantizationInfo(), const QuantizationInfo &dst_qinfo = QuantizationInfo())
Giorgio Arena16def8d2021-10-07 11:03:12 +010053 {
54 ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NDHWC);
55
Giorgio Arena945ae9e2021-10-13 11:13:04 +010056 const TensorShape weights_shape(num_kernels, input_shape[0], kernel_width, kernel_height, kernel_depth);
Giorgio Arena16def8d2021-10-07 11:03:12 +010057 const TensorShape bias_shape(num_kernels);
Giorgio Arena51847d52021-10-19 15:45:57 +010058 const DataType bias_data_type = is_data_type_quantized(data_type) ? DataType::S32 : data_type;
Giorgio Arena5c002ec2021-10-12 16:00:40 +010059 const Conv3dInfo conv3d_info(Size3D(stride_x, stride_y, stride_z), Padding3D(pad_x, pad_y, pad_z), act_info, Size3D(1U, 1U, 1U), DimensionRoundingType::FLOOR, false);
Giorgio Arena16def8d2021-10-07 11:03:12 +010060 const TensorShape output_shape = compute_conv3d_shape(input_shape, weights_shape, conv3d_info);
61
Giorgio Arena51847d52021-10-19 15:45:57 +010062 _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, bias_data_type, data_layout, src_qinfo, weights_qinfo, dst_qinfo);
63 _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, bias_data_type, src_qinfo, weights_qinfo, dst_qinfo);
Giorgio Arena16def8d2021-10-07 11:03:12 +010064 }
65
66protected:
67 template <typename U>
68 void fill(U &&tensor, int i)
69 {
70 switch(tensor.data_type())
71 {
72 case DataType::F16:
73 {
74 arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
75 library->fill(tensor, distribution, i);
76 break;
77 }
78 case DataType::F32:
79 {
80 std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
81 library->fill(tensor, distribution, i);
82 break;
83 }
84 default:
85 library->fill_tensor_uniform(tensor, i);
86 }
87 }
88
Giorgio Arena945ae9e2021-10-13 11:13:04 +010089 TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const Conv3dInfo &conv3d_info,
Giorgio Arena51847d52021-10-19 15:45:57 +010090 bool has_bias, const DataType &data_type, const DataType &bias_data_type, const DataLayout &data_layout, const QuantizationInfo &src_qinfo,
91 const QuantizationInfo &weights_qinfo, const QuantizationInfo &dst_qinfo)
Giorgio Arena16def8d2021-10-07 11:03:12 +010092 {
93 // Create tensors
Giorgio Arena51847d52021-10-19 15:45:57 +010094 TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, src_qinfo, data_layout);
95 TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, weights_qinfo, data_layout);
96 TensorType bias = has_bias ? create_tensor<TensorType>(bias_shape, bias_data_type, 1, QuantizationInfo()) : TensorType();
97 TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, dst_qinfo, data_layout);
Giorgio Arena16def8d2021-10-07 11:03:12 +010098
Giorgio Arena16def8d2021-10-07 11:03:12 +010099 // Create and configure function
100 FunctionType conv{};
101 conv.configure(&src, &weights, has_bias ? &bias : nullptr, &dst, conv3d_info);
102
103 ARM_COMPUTE_ASSERT(src.info()->is_resizable());
104 ARM_COMPUTE_ASSERT(weights.info()->is_resizable());
105 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
106
107 // Allocate tensors
108 src.allocator()->allocate();
109 weights.allocator()->allocate();
110 dst.allocator()->allocate();
111
112 ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
113 ARM_COMPUTE_ASSERT(!weights.info()->is_resizable());
114 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
115
116 // Fill tensors
117 fill(AccessorType(src), 0);
118 fill(AccessorType(weights), 1);
119
120 if(has_bias)
121 {
122 ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
123 bias.allocator()->allocate();
124 ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
125 fill(AccessorType(bias), 2);
126 }
127
128 // Compute Direct Convolution 3D function
129 conv.run();
130
131 return dst;
132 }
133
Giorgio Arena51847d52021-10-19 15:45:57 +0100134 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape,
135 const Conv3dInfo &conv3d_info, bool has_bias, const DataType &data_type, const DataType &bias_data_type, const QuantizationInfo &src_qinfo,
136 const QuantizationInfo &weights_qinfo, const QuantizationInfo &dst_qinfo)
Giorgio Arena16def8d2021-10-07 11:03:12 +0100137 {
138 // Create reference
Giorgio Arena51847d52021-10-19 15:45:57 +0100139 SimpleTensor<T> src{ input_shape, data_type, 1, src_qinfo };
140 SimpleTensor<T> weights{ weights_shape, data_type, 1, weights_qinfo };
141 SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
142 SimpleTensor<T> dst{ output_shape, data_type, 1, dst_qinfo };
Giorgio Arena16def8d2021-10-07 11:03:12 +0100143
144 // Fill reference
145 fill(src, 0);
146 fill(weights, 1);
147
148 if(has_bias)
149 {
150 fill(bias, 2);
151 }
152
Giorgio Arena51847d52021-10-19 15:45:57 +0100153 return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
Giorgio Arena16def8d2021-10-07 11:03:12 +0100154 }
155
156 TensorType _target{};
157 SimpleTensor<T> _reference{};
158};
159
160template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
161class DirectConvolution3DValidationFixture : public DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
162{
163public:
Giorgio Arena16def8d2021-10-07 11:03:12 +0100164 void setup(TensorShape input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
165 unsigned int num_kernels, bool has_bias, ActivationLayerInfo act_info, DataType data_type, DataLayout data_layout)
166 {
167 DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, stride_z, pad_x, pad_y, pad_z, kernel_width, kernel_height,
168 kernel_depth, num_kernels, has_bias, act_info, data_type, data_layout);
169 }
170};
Giorgio Arena51847d52021-10-19 15:45:57 +0100171
172template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
173class DirectConvolution3DValidationQuantizedFixture : public DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
174{
175public:
Giorgio Arena51847d52021-10-19 15:45:57 +0100176 void setup(TensorShape input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
177 unsigned int num_kernels, bool has_bias, ActivationLayerInfo act_info, DataType data_type, DataLayout data_layout, QuantizationInfo src_qinfo, QuantizationInfo weights_qinfo,
178 QuantizationInfo dst_qinfo)
179 {
180 DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, stride_z, pad_x, pad_y, pad_z, kernel_width, kernel_height,
181 kernel_depth, num_kernels, has_bias, act_info, data_type, data_layout, src_qinfo,
182 weights_qinfo, dst_qinfo);
183 }
184};
Giorgio Arena16def8d2021-10-07 11:03:12 +0100185} // namespace validation
186} // namespace test
187} // namespace arm_compute
Jakub Sujak8f4b3df2023-10-30 16:04:51 +0000188
189#endif // ACL_TESTS_VALIDATION_FIXTURES_DIRECTCONVOLUTION3DFIXTURE_H