blob: 1112dcb2fbbd44f040bcca7bf51ba9fa8e2701f7 [file] [log] [blame]
Ramy Elgammalf26ea2f2023-03-24 11:42:03 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef TESTS_VALIDATION_FIXTURES_MATMULFIXTURE
25#define TESTS_VALIDATION_FIXTURES_MATMULFIXTURE
26
27#include "arm_compute/core/Types.h"
28#include "tests/framework/Fixture.h"
29#include "tests/validation/reference/GEMM.h"
30#include "tests/validation/reference/Permute.h"
31#include "tests/validation/reference/Permute.h"
32#include "tests/validation/reference/ReshapeLayer.h"
33#include <random>
34namespace arm_compute
35{
36namespace test
37{
38namespace validation
39{
40template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
41class MatMulValidationFixture : public framework::Fixture
42{
43public:
44 template <typename...>
45 void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type)
46 {
47 // For brevity, the input shapes are assumed to be not-transposed for both Lhs and Rhs matrices.
48 if(pretranspose_a)
49 {
50 permute(shape_a, PermutationVector(1U, 0U));
51 }
52 if(pretranspose_b)
53 {
54 permute(shape_b, PermutationVector(1U, 0U));
55 }
56 _target = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type);
57 _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type);
58 }
59
60protected:
61 template <typename U>
62 void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
63 {
64 switch(tensor.data_type())
65 {
66 case DataType::F16:
67 {
68 arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ float(lo), float(hi) };
69 library->fill(tensor, distribution, i);
70 break;
71 }
72 case DataType::F32:
73 {
74 std::uniform_real_distribution<float> distribution(lo, hi);
75 library->fill(tensor, distribution, i);
76 break;
77 }
78 default:
79 library->fill_tensor_uniform(tensor, i);
80 }
81 }
82 TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type)
83 {
84 // 1. Create Classes and configure function
85 // Create tensors
86 TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
87 TensorType b = create_tensor<TensorType>(shape_b, data_type, 1);
88 TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);
89 FunctionType matmul;
90 // Configure MatMulInfo class
91 MatMulInfo info;
92 info.adj_lhs(pretranspose_a);
93 info.adj_rhs(pretranspose_b);
94 matmul.configure(&a, &b, &dst, info);
95 // Assertions
96 ARM_COMPUTE_ASSERT(a.info()->is_resizable());
97 ARM_COMPUTE_ASSERT(b.info()->is_resizable());
98 ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
99 // Allocate tensors
100 a.allocator()->allocate();
101 b.allocator()->allocate();
102 dst.allocator()->allocate();
103 ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
104 ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
105 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
106
107 // 2. Fill tensors and run once
108 // Fill tensors
109 fill(AccessorType(a), 0);
110 fill(AccessorType(b), 1);
111 matmul.run(); // First run
112
113 return dst;
114 }
115 SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type)
116 {
117 // We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D
118 // This is necessary unless we choose to extend gemm reference for 5D+ tensors
119 TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimW);
120 TensorShape a_shape_collapsed = shape_a.collapsed_from(Window::DimW);
121 TensorShape b_shape_collapsed = shape_b.collapsed_from(Window::DimW);
122
123 // Create reference
124 SimpleTensor<T> a{ a_shape_collapsed, data_type, 1 };
125 SimpleTensor<T> b{ b_shape_collapsed, data_type, 1 };
126 SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };
127
128 // Fill reference
129 fill(a, 0);
130 fill(b, 1);
131
132 /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_a is set to true, then A is assumed to be (B x K x M),
133 therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
134 in order to be able to call reference implementation that works with (B x M x K) input.
135 Similarly, if pretranspose_b is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */
136
137 // Define transposed shapes
138 TensorShape a_transposed_shape(a.shape());
139 a_transposed_shape.set(0, a.shape().y());
140 a_transposed_shape.set(1, a.shape().x());
141 TensorShape b_transposed_shape(b.shape());
142 b_transposed_shape.set(0, b.shape().y());
143 b_transposed_shape.set(1, b.shape().x());
144
145 // Define transposed tensors
146 SimpleTensor<T> a_transposed{ a_transposed_shape, data_type };
147 SimpleTensor<T> b_transposed{ b_transposed_shape, data_type };
148
149 // pretranspose a if necessary
150 if(pretranspose_a)
151 {
152 a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U));
153 }
154
155 // pretranspose b if necessary
156 if(pretranspose_b)
157 {
158 b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
159 }
160
161 // Setting beta to 0 will effectively disable C for the
162 // computation of the reference: alpha * A * B + 0 * C
163 // Use transposed tensors if boolean enabled else use original tensors
164 SimpleTensor<T> result = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f);
165
166 // We reshape the gemm output back if the tensor is high dimensional
167 if(output_shape_collapsed != output_shape)
168 {
169 result = reference::reshape_layer(result, output_shape);
170 }
171
172 return result;
173 }
174 TensorType _target{};
175 SimpleTensor<T> _reference{};
176};
177} // namespace validation
178} // namespace test
179} // namespace arm_compute
180#endif /* TESTS_VALIDATION_FIXTURES_MATMULFIXTURE */