/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef TESTS_VALIDATION_FIXTURES_MATMULFIXTURE
#define TESTS_VALIDATION_FIXTURES_MATMULFIXTURE

#include "arm_compute/core/Types.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/ReshapeLayer.h"
#include <random>

namespace arm_compute
{
namespace test
{
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulGenericValidationFixture : public framework::Fixture
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
               Settings settings)
    {
        // For brevity, the input shapes are assumed to be non-transposed for both the a and b matrices.
        if(transpose_a)
        {
            permute(shape_a, PermutationVector(1U, 0U));
        }
        if(transpose_b)
        {
            permute(shape_b, PermutationVector(1U, 0U));
        }
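        // Illustration (assuming the ACL shape layout where dimension 0 = columns and dimension 1 = rows,
        // so an M x K matrix has shape (K, M)): for M = 2, K = 3, a dataset provides shape_a = TensorShape(3U, 2U);
        // with transpose_a the permute above turns it into TensorShape(2U, 3U), i.e. a K x M input tensor.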

        _target    = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings);
        _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info);
    }

protected:
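    // Fills the tensor with uniformly distributed values between lo and hi (for the float types handled
    // below; other data types fall back to the library's default range). The index i acts as the seed
    // offset for the library's RNG, so passing different indices for a and b gives decorrelated data.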
    template <typename U>
    void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
    {
        switch(tensor.data_type())
        {
            case DataType::F16:
            {
                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ float(lo), float(hi) };
                library->fill(tensor, distribution, i);
                break;
            }
            case DataType::F32:
            {
                std::uniform_real_distribution<float> distribution(lo, hi);
                library->fill(tensor, distribution, i);
                break;
            }
            default:
            {
                library->fill_tensor_uniform(tensor, i);
            }
        }
    }

    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
                              ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings)
    {
        // 1. Create classes and configure function
        // ----------------------------------------------------
        // Create tensors
        // Configure relevant classes and matmul function
        TensorType a   = create_tensor<TensorType>(shape_a, data_type, 1);
        TensorType b   = create_tensor<TensorType>(shape_b, data_type, 1);
        TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);

        FunctionType matmul;

        // Configure MatMulInfo class
        MatMulInfo mm_info;
        mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b).fused_activation(act_info);
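        // adj_lhs()/adj_rhs() flag the corresponding operand as transposed ("adjoint") in memory,
        // so the operator accounts for the transposition internally instead of the caller undoing it.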

        // Ensure values are dynamic
        a.info()->set_are_values_constant(false);
        b.info()->set_are_values_constant(false);

        // Configure operator
        matmul.configure(&a, &b, &dst, mm_info, settings);

        // Assertions
        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(b.info()->is_resizable());
        ARM_COMPUTE_ASSERT(dst.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        b.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // For multiple runs.
        for(int i = 0; i < num_extra_runs; i++)
        {
            // Stress dynamic tensors by running multiple times.
            // --------------------------------------------------------
            // Fill tensors with a new seed on each run
            // Run function
            const int seed_offset = (i + 1) * 100; // Vary the seed per run so each extra run sees fresh data
            fill(AccessorType(a), seed_offset);
            fill(AccessorType(b), seed_offset + 1);

            matmul.run();
        }

        // 2. Final run for reference comparison
        // --------------------------------------------------------
        // Re-fill tensors with the same seeds as the reference run
        // Compute MatMul operation
        fill(AccessorType(a), 2);
        fill(AccessorType(b), 3);

        matmul.run();

        return dst;
    }

    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
                                      ActivationLayerInfo act_info)
    {
        // We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D.
        // This is necessary unless we choose to extend the gemm reference to support 5D+ tensors.
        TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimW);
        TensorShape a_shape_collapsed      = a_shape.collapsed_from(Window::DimW);
        TensorShape b_shape_collapsed      = b_shape.collapsed_from(Window::DimW);
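        // For example (illustrative shape, not from a dataset): a 5D shape (2, 3, 4, 5, 6) collapses
        // from Window::DimW (dimension 3) onwards into (2, 3, 4, 30).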

        // Create reference
        SimpleTensor<T> a{ a_shape_collapsed, data_type, 1 };
        SimpleTensor<T> b{ b_shape_collapsed, data_type, 1 };
        SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };

        // Fill reference
        fill(a, 2);
        fill(b, 3);

        /* Note: Assuming the usual batch matmul dimensions A = (B x M x K) and B = (B x K x N): if transpose_a is set to true, then A is assumed to be (B x K x M),
           so A must be pre-transposed before being passed to the fixture. The fixture then transposes A again to restore (B x M x K),
           so that the reference implementation, which expects (B x M x K) input, can be called.
           Similarly, if transpose_b is set to true, then B is assumed to be (B x N x K) and must be pre-transposed before being passed to the fixture. */
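        // Worked example (illustrative values): for M = 2, K = 3 with transpose_a = true, the input tensor a
        // holds a K x M matrix, i.e. ACL shape (M, K) = (2, 3); reference::permute with PermutationVector(1U, 0U)
        // below restores shape (3, 2) = (K, M), the M x K layout reference::gemm expects.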

        // Define transposed shapes
        TensorShape a_transposed_shape(a.shape());
        a_transposed_shape.set(0, a.shape().y());
        a_transposed_shape.set(1, a.shape().x());

        TensorShape b_transposed_shape(b.shape());
        b_transposed_shape.set(0, b.shape().y());
        b_transposed_shape.set(1, b.shape().x());

        // Define transposed tensors
        SimpleTensor<T> a_transposed{ a_transposed_shape, data_type };
        SimpleTensor<T> b_transposed{ b_transposed_shape, data_type };

        // Pre-transpose a if necessary
        if(transpose_a)
        {
            a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U));
        }
        // Pre-transpose b if necessary
        if(transpose_b)
        {
            b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
        }

        // Setting beta to 0 effectively disables C in the reference
        // computation: alpha * A * B + 0 * C
        // Use the transposed tensors when the corresponding flag is enabled, otherwise use the original tensors
        SimpleTensor<T> result = reference::gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f);
        result                 = reference::activation_layer<T>(result, act_info, QuantizationInfo());

        // We reshape the gemm output back if the tensor is high dimensional
        if(output_shape_collapsed != output_shape)
        {
            result = reference::reshape_layer(result, output_shape);
        }

        return result;
    }

    TensorType      _target{};
    SimpleTensor<T> _reference{};
};

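// Minimal usage sketch (illustrative only; the backend types, dataset, and tolerance below are
// assumptions, not part of this file):
//
//   using CpuMatMulFixture = MatMulValidationFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, float>;
//   FIXTURE_DATA_TEST_CASE(RunSmall, CpuMatMulFixture, framework::DatasetMode::ALL, /* ...dataset... */)
//   {
//       validate(Accessor(_target), _reference, tolerance);
//   }
//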
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type)
    {
        MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0,
                                                                                                   Settings());
    }
};

template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info)
    {
        MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
    }
};

template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
{
public:
    template <typename...>
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs)
    {
        MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings());
    }
};

} // namespace validation
} // namespace test
} // namespace arm_compute
#endif /* TESTS_VALIDATION_FIXTURES_MATMULFIXTURE */