Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (c) 2023 Arm Limited. |
| 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
Jakub Sujak | e9b3ee2 | 2023-04-17 12:08:48 +0100 | [diff] [blame] | 24 | #ifndef ACL_TESTS_VALIDATION_FIXTURES_MATMULFIXTURE |
| 25 | #define ACL_TESTS_VALIDATION_FIXTURES_MATMULFIXTURE |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 26 | |
| 27 | #include "arm_compute/core/Types.h" |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 28 | #include "arm_compute/core/Utils.h" |
| 29 | #include "arm_compute/core/utils/quantization/AsymmHelpers.h" |
Viet-Hoa Do | a62129a | 2023-04-26 15:38:45 +0100 | [diff] [blame] | 30 | #include "src/core/utils/quantization/AsymmHelpers.h" |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 31 | #include "tests/framework/Fixture.h" |
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 32 | #include "tests/validation/reference/ActivationLayer.h" |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 33 | #include "tests/validation/reference/GEMM.h" |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 34 | #include "tests/validation/reference/GEMMLowp.h" |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 35 | #include "tests/validation/reference/Permute.h" |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 36 | #include "tests/validation/reference/ReshapeLayer.h" |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 37 | #include <limits> |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 38 | #include <random> |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 39 | #include <type_traits> |
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 40 | |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 41 | namespace arm_compute |
| 42 | { |
| 43 | namespace test |
| 44 | { |
| 45 | namespace validation |
| 46 | { |
/** Generic fixture validating a MatMul function against a reference GEMM/GEMMLowp computation.
 *
 * Runs the function under test (compute_target) and a reference implementation
 * (compute_reference) on identically-seeded inputs; the test case compares
 * _target against _reference.
 *
 * @tparam TensorType   Backend tensor type under test.
 * @tparam AccessorType Accessor used to fill backend tensors from the host.
 * @tparam FunctionType MatMul function/operator being validated.
 * @tparam Settings     Backend-specific MatMul settings type.
 * @tparam T            Element type used by the reference SimpleTensor.
 */
template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
class MatMulGenericValidationFixture : public framework::Fixture
{
public:
    /** Sets up and runs both the target and the reference computation.
     *
     * @param shape_a..o_qinfo Shapes, transpose flags, data type, activation, number of extra
     *        stress runs, backend settings and (optional) quantization info for a, b and output.
     */
    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
               Settings settings, QuantizationInfo a_qinfo = QuantizationInfo(), QuantizationInfo b_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
    {
        // For brevity, the input shapes are assumed to be not-transposed for both a and b matrices.
        // When a transpose flag is set, the shape is permuted here so the function receives the
        // tensor in its "adjoint" layout.
        if(transpose_a)
        {
            permute(shape_a, PermutationVector(1U, 0U));
        }
        if(transpose_b)
        {
            permute(shape_b, PermutationVector(1U, 0U));
        }

        _target    = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings, a_qinfo, b_qinfo, o_qinfo);
        _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, a_qinfo, b_qinfo, o_qinfo);
    }

protected:
    /** Fills @p tensor with pseudo-random values appropriate for its data type.
     *
     * @param tensor Tensor (or accessor) to fill.
     * @param i      RNG seed — the same seed is used for target and reference fills
     *               so that both computations see identical input data.
     * @param lo,hi  Value range for the floating-point distributions (ignored for quantized types).
     */
    template <typename U>
    void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
    {
        switch(tensor.data_type())
        {
            case DataType::F16:
            {
                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ float(lo), float(hi) };
                library->fill(tensor, distribution, i);
                break;
            }
            case DataType::F32:
            {
                std::uniform_real_distribution<float> distribution(lo, hi);
                library->fill(tensor, distribution, i);
                break;
            }
            case DataType::QASYMM8:
            case DataType::QASYMM8_SIGNED:
            {
                // Quantized types are filled uniformly over their integer range; lo/hi do not apply.
                library->fill_tensor_uniform(tensor, i);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Unsupported data type.");
            }
        }
    }

    /** Configures, allocates and runs the MatMul function under test.
     *
     * Performs @p num_extra_runs stress iterations with dynamic (non-constant) input values
     * before the final run whose inputs are seeded identically to the reference (seeds 2 and 3).
     *
     * @return The output tensor of the final run.
     */
    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
                              ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
    {
        // 1. Create Classes and configure function
        // ----------------------------------------------------
        // Create tensors
        // Configure relevant classes and matmul function
        TensorType a   = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
        TensorType b   = create_tensor<TensorType>(shape_b, data_type, 1, b_qinfo);
        TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, o_qinfo);

        FunctionType matmul;

        // Configure MatMulInfo class
        MatMulInfo mm_info;
        mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b);

        // Ensure values are dynamic
        a.info()->set_are_values_constant(false);
        b.info()->set_are_values_constant(false);

        // Configure operator
        matmul.configure(&a, &b, &dst, mm_info, settings, act_info);

        // Assertions
        ARM_COMPUTE_ASSERT(a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(b.info()->is_resizable());
        ARM_COMPUTE_ASSERT(dst.info()->is_resizable());

        // Allocate tensors
        a.allocator()->allocate();
        b.allocator()->allocate();
        dst.allocator()->allocate();

        ARM_COMPUTE_ASSERT(!a.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!b.info()->is_resizable());
        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

        // For multiple runs.
        for(int i = 0; i < num_extra_runs; i++)
        {
            // Stress dynamic tensors by running multiple times.
            // --------------------------------------------------------
            // Fill tensors with new seed
            // Run function
            // NOTE(review): seed_offset does not depend on the loop index, so every extra run
            // reuses the same seed; it only guarantees the seeds differ from the final run's
            // (2 and 3). Confirm whether per-run seeds (e.g. i * 100) were intended.
            const int seed_offset = num_extra_runs * 100;
            fill(AccessorType(a), seed_offset);
            fill(AccessorType(b), seed_offset + 1);

            matmul.run();
        }

        // 2. Final Run for reference comparison
        // --------------------------------------------------------
        // Re-fill tensors same seed as reference run
        // Compute MatMul operation
        fill(AccessorType(a), 2);
        fill(AccessorType(b), 3);

        matmul.run();

        return dst;
    }

    // Reference GEMM for non-integral (floating-point) element types:
    // plain alpha * A * B + beta * C; the output quantization info is unused.
    template <typename TT>
    typename std::enable_if < !std::is_integral<TT>::value, SimpleTensor<TT >>::type
    compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
    {
        ARM_COMPUTE_UNUSED(o_qinfo);

        return reference::gemm(a, b, c, alpha, beta);
    }

    // Reference GEMM for integral (quantized) element types: int32 GEMMLowp core followed by
    // fixed-point requantization to the output quantization info; alpha/beta are unused.
    template <typename TT>
    typename std::enable_if<std::is_integral<TT>::value, SimpleTensor<TT>>::type
    compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const QuantizationInfo &o_qinfo)
    {
        ARM_COMPUTE_UNUSED(alpha, beta);

        const auto aq = a.quantization_info().uniform();
        const auto bq = b.quantization_info().uniform();
        const auto oq = o_qinfo.uniform();

        // Combined rescale factor mapping the int32 accumulator to the output scale.
        const auto multiplier = aq.scale * bq.scale / oq.scale;

        int32_t output_multiplier = 0;
        int32_t output_shift      = 0;
        quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
        std::vector<int32_t> output_multipliers{ output_multiplier };
        std::vector<int32_t> output_shifts{ output_shift };

        //The lhs and rhs offsets are negated here to keep the reference aligned with the function implementation where the lhs and rhs offsets are also negated.
        const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>(
                             a, b, c.shape(), -aq.offset, -bq.offset);

        auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TT>(
                          tmp, output_multipliers, output_shifts, oq.offset,
                          std::numeric_limits<int32_t>::lowest(), std::numeric_limits<int32_t>::max());
        output.quantization_info(o_qinfo);

        return output;
    }

    /** Computes the reference result: GEMM (or GEMMLowp) on shape-collapsed inputs, optionally
     *  un-transposing pre-transposed operands, followed by the activation layer and a reshape
     *  back to the original output shape when it was collapsed.
     */
    SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
                                      ActivationLayerInfo act_info, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
    {
        // We collapse dimensions > 2 onto dimension 2, i.e. 4D+ tensors will look like 3D
        // This is necessary unless we choose to extend gemm reference for 4D+ tensors
        TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimZ);
        TensorShape a_shape_collapsed      = a_shape.collapsed_from(Window::DimZ);
        TensorShape b_shape_collapsed      = b_shape.collapsed_from(Window::DimZ);

        // Create reference
        SimpleTensor<T> a{ a_shape_collapsed, data_type, 1, a_qinfo };
        SimpleTensor<T> b{ b_shape_collapsed, data_type, 1, b_qinfo };
        SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };

        // Fill reference
        // Seeds 2 and 3 deliberately match the final fill in compute_target().
        fill(a, 2);
        fill(b, 3);

        /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if transpose_a is set to true, then A is assumed to be (B x K x M),
        therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
        in order to be able to call reference implementation that works with (B x M x K) input.
        Similarly, if transpose_b is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */

        // Define transposed shapes
        TensorShape a_transposed_shape(a.shape());
        a_transposed_shape.set(0, a.shape().y());
        a_transposed_shape.set(1, a.shape().x());

        TensorShape b_transposed_shape(b.shape());
        b_transposed_shape.set(0, b.shape().y());
        b_transposed_shape.set(1, b.shape().x());

        // Define transposed tensors
        SimpleTensor<T> a_transposed{ a_transposed_shape, data_type };
        SimpleTensor<T> b_transposed{ b_transposed_shape, data_type };

        // pretranspose a if necessary
        if(transpose_a)
        {
            a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U));
        }
        // pretranspose b if necessary
        if(transpose_b)
        {
            b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
        }

        // Setting beta to 0 will effectively disable C for the
        // computation of the reference: alpha * A * B + 0 * C
        // Use transposed tensors if boolean enabled else use original tensors
        auto result = compute_reference_gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f, o_qinfo);

        result = reference::activation_layer<T>(result, act_info, o_qinfo);

        // We reshape the gemm output back if the tensor is high dimensional
        if(output_shape_collapsed != output_shape)
        {
            result = reference::reshape_layer(result, output_shape);
        }

        return result;
    }

    TensorType      _target{};    // Output of the function under test.
    SimpleTensor<T> _reference{}; // Reference result compared against _target.
};
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 268 | |
| 269 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
| 270 | class MatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 271 | { |
| 272 | public: |
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 273 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type) |
| 274 | { |
| 275 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0, |
| 276 | Settings()); |
| 277 | } |
| 278 | }; |
| 279 | |
| 280 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 281 | class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 282 | { |
| 283 | public: |
Mohammed Suhail Munshi | a1b1e41 | 2023-03-23 22:21:31 +0000 | [diff] [blame] | 284 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs) |
| 285 | { |
| 286 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings()); |
| 287 | } |
| 288 | }; |
| 289 | |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 290 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
| 291 | class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 292 | { |
| 293 | public: |
Mohammed Suhail Munshi | 94abde4 | 2023-05-25 16:48:43 +0100 | [diff] [blame] | 294 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, |
| 295 | QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 296 | { |
Mohammed Suhail Munshi | 94abde4 | 2023-05-25 16:48:43 +0100 | [diff] [blame] | 297 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), |
| 298 | a_qinfo, b_qinfo, o_qinfo); |
| 299 | } |
| 300 | }; |
| 301 | |
| 302 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
| 303 | class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 304 | { |
| 305 | public: |
Mohammed Suhail Munshi | 94abde4 | 2023-05-25 16:48:43 +0100 | [diff] [blame] | 306 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info) |
| 307 | { |
| 308 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings()); |
| 309 | } |
| 310 | }; |
| 311 | |
| 312 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
| 313 | class MatMulValidationWithActivationAlphaBetaFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 314 | { |
| 315 | public: |
Mohammed Suhail Munshi | 94abde4 | 2023-05-25 16:48:43 +0100 | [diff] [blame] | 316 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function, |
| 317 | float alpha_beta) |
| 318 | { |
| 319 | ActivationLayerInfo act_info(function, alpha_beta, alpha_beta); |
| 320 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings()); |
| 321 | } |
| 322 | }; |
| 323 | |
| 324 | template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T> |
| 325 | class QuantizedMatMulValidationWithActivationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T> |
| 326 | { |
| 327 | public: |
Mohammed Suhail Munshi | 94abde4 | 2023-05-25 16:48:43 +0100 | [diff] [blame] | 328 | void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function, |
| 329 | float alpha_beta, int num_extra_runs, |
| 330 | QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) |
| 331 | { |
| 332 | ActivationLayerInfo act_info(function, alpha_beta, alpha_beta); |
| 333 | MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), |
| 334 | a_qinfo, b_qinfo, o_qinfo); |
Viet-Hoa Do | 9c7c2d2 | 2023-04-11 17:16:27 +0100 | [diff] [blame] | 335 | } |
| 336 | }; |
| 337 | |
Ramy Elgammal | f26ea2f | 2023-03-24 11:42:03 +0000 | [diff] [blame] | 338 | } // namespace validation |
| 339 | } // namespace test |
| 340 | } // namespace arm_compute |
Jakub Sujak | e9b3ee2 | 2023-04-17 12:08:48 +0100 | [diff] [blame] | 341 | #endif /* ACL_TESTS_VALIDATION_FIXTURES_MATMULFIXTURE */ |