Viet-Hoa Do | 77bbe2e | 2023-12-06 11:01:15 +0000 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (c) 2023-2024 Arm Limited. |
| 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
| 24 | |
| 25 | #ifdef ARM_COMPUTE_ENABLE_SME2 |
| 26 | |
| 27 | #include "arm_compute/core/ITensor.h" |
| 28 | #include "arm_compute/core/Window.h" |
| 29 | |
| 30 | namespace arm_compute |
| 31 | { |
| 32 | namespace cpu |
| 33 | { |
| 34 | |
| 35 | // SoftMax |
| 36 | // |
| 37 | // Steps: |
| 38 | // * Find max: max_value = max(src) |
| 39 | // * Regularize: dst[i] = exp(src[i] - max_value) |
| 40 | // sum_value = sum(dst) |
| 41 | // * Normalize: dst[i] = dst[i] / sum_value |
| 42 | void sme2_f32_softmax_kernel( // |
| 43 | const float *src, |
| 44 | float *dst, |
| 45 | float beta, |
| 46 | const uintptr_t shape[4], |
| 47 | const uintptr_t src_strides[4], |
| 48 | const uintptr_t dst_strides[4]) |
| 49 | { |
| 50 | // Precondition: |
| 51 | // * src_strides[0] == sizeof(float) |
| 52 | // * dst_strides[0] == sizeof(float) |
| 53 | |
| 54 | __asm__ volatile( |
| 55 | R"( |
| 56 | .inst 0xd503477f // smstart |
| 57 | |
| 58 | // Registers |
| 59 | // |
| 60 | // * x9: temporary, index |
| 61 | // * x10: temporary, -inf |
| 62 | // * x11: temporary, 0 |
| 63 | // * x12: temporary, 1.0f |
| 64 | // * x13: temporary, body_length |
| 65 | // |
| 66 | // * x20: index_3 |
| 67 | // * x21: src_3 |
| 68 | // * x22: dst_3 |
| 69 | // * x23: index_2 |
| 70 | // * x24: src_2 |
| 71 | // * x25: dst_2 |
| 72 | // * x26: index_1 |
| 73 | // * x27: src_1 |
| 74 | // * x28: dst_1 |
| 75 | // |
| 76 | // * z0: c1 |
| 77 | // * z1: c2 |
| 78 | // * z2: c3 |
| 79 | // * z3: c4 |
| 80 | // * z4: c5 |
| 81 | // * z5: shift |
| 82 | // * z6: inv_ln2 |
| 83 | // * z7: neg_ln2_hi |
| 84 | // * z8: neg_ln2_lo |
| 85 | // * z9: min_input |
| 86 | // * z10: 23, 0 |
| 87 | // * z11: max_value |
| 88 | // * z12-z15: x, r_hi, r, r2 |
| 89 | // * z16-z19: max_value, shift, z, scale, poly |
| 90 | // * z20-z21: n, p1, p12345 |
| 91 | // * z22-z23: n, p23, p2345 |
| 92 | // * z24-z25: p45 |
| 93 | // * z26: beta |
| 94 | // * z28-z31: sum_value |
| 95 | // |
| 96 | // * za0-za3: sum_value |
| 97 | // |
| 98 | // * p0: all-true |
| 99 | // * p1: left-over predicate |
| 100 | // * p4-p7: underflow |
| 101 | // * pn9: all-true |
| 102 | |
| 103 | // Prepares all constant values |
| 104 | |
| 105 | ptrue p0.b |
| 106 | .inst 0x25207811 // ptrue pn9.b |
| 107 | |
| 108 | mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 |
| 109 | mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb |
| 110 | mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 |
| 111 | mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 |
| 112 | mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 |
| 113 | |
| 114 | movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 |
| 115 | movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb |
| 116 | movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 |
| 117 | movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 |
| 118 | movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 |
| 119 | |
| 120 | dup z0.s, w9 // c1. |
| 121 | dup z1.s, w10 // c2. |
| 122 | dup z2.s, w11 // c3. |
| 123 | dup z3.s, w12 // c4. |
| 124 | dup z4.s, w13 // c5. |
| 125 | |
| 126 | mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f |
| 127 | mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b |
| 128 | mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 |
| 129 | mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e |
| 130 | mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae |
| 131 | |
| 132 | movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f |
| 133 | movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b |
| 134 | movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 |
| 135 | movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e |
| 136 | movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae |
| 137 | |
| 138 | dup z5.s, w9 // shift |
| 139 | dup z6.s, w10 // inv_ln2 |
| 140 | dup z7.s, w11 // neg_ln2_hi |
| 141 | dup z8.s, w12 // neg_ln2_lo |
| 142 | dup z9.s, w13 // min_input |
| 143 | |
| 144 | dup z26.s, %w[beta] // beta |
| 145 | |
| 146 | mov w10, #0x0000 // -inf: 0xff800000 |
| 147 | movk w10, #0xff80 // -inf: 0xff800000 |
| 148 | |
| 149 | mov w11, #0 // 0 |
| 150 | |
| 151 | // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl |
| 152 | cntw x13, ALL, MUL #4 |
| 153 | udiv x9, %x[length], x13 |
| 154 | mul x13, x13, x9 |
| 155 | |
| 156 | // ================================================== |
| 157 | // 3D loop opening |
| 158 | // ================================================== |
| 159 | |
| 160 | mov x20, %x[shape_3] |
| 161 | mov x21, %x[src] |
| 162 | mov x22, %x[dst] |
| 163 | |
| 164 | loop_3_start%=: |
| 165 | // for index_3 in shape_3 downto 1 |
| 166 | cmp x20, #0 |
| 167 | b.eq loop_3_end%= |
| 168 | sub x20, x20, #1 |
| 169 | |
| 170 | mov x23, %x[shape_2] |
| 171 | mov x24, x21 |
| 172 | mov x25, x22 |
| 173 | |
| 174 | loop_2_start%=: |
| 175 | // for index_2 in shape_2 downto 1 |
| 176 | cmp x23, #0 |
| 177 | b.eq loop_2_end%= |
| 178 | sub x23, x23, #1 |
| 179 | |
| 180 | mov x26, %x[shape_1] |
| 181 | mov x27, x24 |
| 182 | mov x28, x25 |
| 183 | |
| 184 | loop_1_start%=: |
| 185 | // for index_1 in shape_2 downto 1 |
| 186 | cmp x26, #0 |
| 187 | b.eq loop_1_end%= |
| 188 | sub x26, x26, #1 |
| 189 | |
| 190 | // ================================================== |
| 191 | // Step 1: Find max |
| 192 | // ================================================== |
| 193 | |
| 194 | // ---------------------------------------------------------------- z16-z19: max_value = -inf |
| 195 | mov z16.d, z11.d |
| 196 | mov z17.d, z11.d |
| 197 | mov z18.d, z11.d |
| 198 | mov z19.d, z11.d |
| 199 | |
| 200 | // Loop for processing 4 vectors per iteration. |
| 201 | mov x9, #0 // x9: index |
| 202 | dup z11.s, w10 // z11: max_value = -inf |
| 203 | |
| 204 | find_max_body_start%=: |
| 205 | cmp x9, x13 |
| 206 | b.eq find_max_body_end%= |
| 207 | |
| 208 | .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x |
| 209 | .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x) |
| 210 | |
| 211 | incw x9, ALL, MUL #4 |
| 212 | b find_max_body_start%= |
| 213 | find_max_body_end%=: |
| 214 | |
| 215 | // Loop for processing the leftover part. |
| 216 | find_max_leftover_start%=: |
| 217 | whilelo p1.s, x9, %x[length] |
| 218 | b.none find_max_leftover_end%= |
| 219 | |
| 220 | ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x |
| 221 | fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x) |
| 222 | |
| 223 | incw x9 |
| 224 | b find_max_leftover_start%= |
| 225 | find_max_leftover_end%=: |
| 226 | |
| 227 | // ---------------------------------------------------------------- z16: max_value |
| 228 | .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s} |
| 229 | fmax z16.s, p0/m, z16.s, z17.s |
| 230 | fmaxv s16, p0, z16.s |
| 231 | |
| 232 | // ---------------------------------------------------------------- z11: max_value |
| 233 | dup z11.s, z16.s[0] |
| 234 | |
| 235 | // ================================================== |
| 236 | // Step 2: Regularize |
| 237 | // ================================================== |
| 238 | |
| 239 | .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value |
| 240 | |
| 241 | // Loop for processing 4 vectors per iteration. |
| 242 | mov x9, #0 // ---------------------------------------------------- x9: index |
| 243 | |
| 244 | regularize_body_start%=: |
| 245 | cmp x9, x13 |
| 246 | b.eq regularize_body_end%= |
| 247 | |
| 248 | // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data |
| 249 | .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] |
| 250 | |
| 251 | // ---------------------------------------------------------------- z12-z15: x = input_data - max_value |
| 252 | fsub z12.s, z12.s, z11.s |
| 253 | fsub z13.s, z13.s, z11.s |
| 254 | fsub z14.s, z14.s, z11.s |
| 255 | fsub z15.s, z15.s, z11.s |
| 256 | |
| 257 | // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta |
| 258 | fmul z12.s, z12.s, z26.s |
| 259 | fmul z13.s, z13.s, z26.s |
| 260 | fmul z14.s, z14.s, z26.s |
| 261 | fmul z15.s, z15.s, z26.s |
| 262 | |
| 263 | // ---------------------------------------------------------------- z16-z19: shift |
| 264 | mov z16.d, z5.d |
| 265 | mov z17.d, z5.d |
| 266 | mov z18.d, z5.d |
| 267 | mov z19.d, z5.d |
| 268 | |
| 269 | // ---------------------------------------------------------------- p4-p7: underflow = x < min_input |
| 270 | fcmlt p4.s, p0/z, z12.s, z9.s |
| 271 | fcmlt p5.s, p0/z, z13.s, z9.s |
| 272 | fcmlt p6.s, p0/z, z14.s, z9.s |
| 273 | fcmlt p7.s, p0/z, z15.s, z9.s |
| 274 | |
| 275 | // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 |
| 276 | fmla z16.s, p0/m, z12.s, z6.s |
| 277 | fmla z17.s, p0/m, z13.s, z6.s |
| 278 | fmla z18.s, p0/m, z14.s, z6.s |
| 279 | fmla z19.s, p0/m, z15.s, z6.s |
| 280 | |
| 281 | // ---------------------------------------------------------------- z20-z23: n = z - shift |
| 282 | fsub z20.s, z16.s, z5.s |
| 283 | fsub z21.s, z17.s, z5.s |
| 284 | fsub z22.s, z18.s, z5.s |
| 285 | fsub z23.s, z19.s, z5.s |
| 286 | |
| 287 | // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi |
| 288 | fmla z12.s, p0/m, z20.s, z7.s |
| 289 | fmla z13.s, p0/m, z21.s, z7.s |
| 290 | fmla z14.s, p0/m, z22.s, z7.s |
| 291 | fmla z15.s, p0/m, z23.s, z7.s |
| 292 | |
| 293 | // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo |
| 294 | fmla z12.s, p0/m, z20.s, z8.s |
| 295 | fmla z13.s, p0/m, z21.s, z8.s |
| 296 | fmla z14.s, p0/m, z22.s, z8.s |
| 297 | fmla z15.s, p0/m, z23.s, z8.s |
| 298 | |
| 299 | // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) |
| 300 | dup z10.s, #23 |
| 301 | urshl z16.s, p0/m, z16.s, z10.s |
| 302 | urshl z17.s, p0/m, z17.s, z10.s |
| 303 | urshl z18.s, p0/m, z18.s, z10.s |
| 304 | urshl z19.s, p0/m, z19.s, z10.s |
| 305 | |
| 306 | // Processes the first 2 vectors. |
| 307 | |
| 308 | // ---------------------------------------------------------------- z20-z21: p1 = r * c1 |
| 309 | fmul z20.s, z12.s, z0.s |
| 310 | fmul z21.s, z13.s, z0.s |
| 311 | |
| 312 | // ---------------------------------------------------------------- z22-z23: p23 = c2 |
| 313 | mov z22.d, z1.d |
| 314 | mov z23.d, z1.d |
| 315 | |
| 316 | // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 |
| 317 | fmla z22.s, p0/m, z12.s, z2.s |
| 318 | fmla z23.s, p0/m, z13.s, z2.s |
| 319 | |
| 320 | // ---------------------------------------------------------------- z24-z35: c4 |
| 321 | mov z24.d, z3.d |
| 322 | mov z25.d, z3.d |
| 323 | |
| 324 | // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 |
| 325 | fmla z24.s, p0/m, z12.s, z4.s |
| 326 | fmla z25.s, p0/m, z13.s, z4.s |
| 327 | |
| 328 | // ---------------------------------------------------------------- z12-z13: r2 = r * r |
| 329 | fmul z12.s, z12.s, z12.s |
| 330 | fmul z13.s, z13.s, z13.s |
| 331 | |
| 332 | // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 |
| 333 | fmla z22.s, p0/m, z12.s, z24.s |
| 334 | fmla z23.s, p0/m, z13.s, z25.s |
| 335 | |
| 336 | // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 |
| 337 | fmla z20.s, p0/m, z12.s, z22.s |
| 338 | fmla z21.s, p0/m, z13.s, z23.s |
| 339 | |
| 340 | // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale |
| 341 | fmla z16.s, p0/m, z20.s, z16.s |
| 342 | fmla z17.s, p0/m, z21.s, z17.s |
| 343 | |
| 344 | // Processes the last 2 vectors |
| 345 | |
| 346 | // ---------------------------------------------------------------- z20-z21: p1 = r * c1 |
| 347 | fmul z20.s, z14.s, z0.s |
| 348 | fmul z21.s, z15.s, z0.s |
| 349 | |
| 350 | // ---------------------------------------------------------------- z22-z23: p23 = c2 |
| 351 | mov z22.d, z1.d |
| 352 | mov z23.d, z1.d |
| 353 | |
| 354 | // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 |
| 355 | fmla z22.s, p0/m, z14.s, z2.s |
| 356 | fmla z23.s, p0/m, z15.s, z2.s |
| 357 | |
| 358 | // ---------------------------------------------------------------- z24-z35: c4 |
| 359 | mov z24.d, z3.d |
| 360 | mov z25.d, z3.d |
| 361 | |
| 362 | // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 |
| 363 | fmla z24.s, p0/m, z14.s, z4.s |
| 364 | fmla z25.s, p0/m, z15.s, z4.s |
| 365 | |
| 366 | // ---------------------------------------------------------------- z14-z15: r2 = r * r |
| 367 | fmul z14.s, z14.s, z14.s |
| 368 | fmul z15.s, z15.s, z15.s |
| 369 | |
| 370 | // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 |
| 371 | fmla z22.s, p0/m, z14.s, z24.s |
| 372 | fmla z23.s, p0/m, z15.s, z25.s |
| 373 | |
| 374 | // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 |
| 375 | fmla z20.s, p0/m, z14.s, z22.s |
| 376 | fmla z21.s, p0/m, z15.s, z23.s |
| 377 | |
| 378 | // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale |
| 379 | fmla z18.s, p0/m, z20.s, z18.s |
| 380 | fmla z19.s, p0/m, z21.s, z19.s |
| 381 | |
| 382 | // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly |
| 383 | dup z10.s, #0 |
| 384 | sel z16.s, p4, z10.s, z16.s |
| 385 | sel z17.s, p5, z10.s, z17.s |
| 386 | sel z18.s, p6, z10.s, z18.s |
| 387 | sel z19.s, p7, z10.s, z19.s |
| 388 | |
| 389 | // Stores 4 consecutive registers to the output |
| 390 | .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2] |
| 391 | |
| 392 | .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly |
| 393 | |
| 394 | incw x9, ALL, MUL #4 |
| 395 | b regularize_body_start%= |
| 396 | regularize_body_end%=: |
| 397 | |
| 398 | // ---------------------------------------------------------------- z28: sum_value |
| 399 | .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] |
| 400 | fadd z28.s, z28.s, z29.s |
| 401 | fadd z30.s, z30.s, z31.s |
| 402 | fadd z28.s, z28.s, z30.s |
| 403 | |
| 404 | // Loop for processing the leftover part. |
| 405 | regularize_leftover_start%=: |
| 406 | whilelo p1.s, x9, %x[length] |
| 407 | b.none regularize_leftover_end%= |
| 408 | |
| 409 | ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data |
| 410 | |
| 411 | fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value |
| 412 | fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta |
| 413 | |
| 414 | mov z16.d, z5.d // z16: shift |
| 415 | fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input |
| 416 | fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 |
| 417 | fsub z20.s, z16.s, z5.s // z20: n = z - shift |
| 418 | fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi |
| 419 | fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo |
| 420 | dup z10.s, #23 // z10: 23 |
| 421 | urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n) |
| 422 | fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 |
| 423 | mov z22.d, z1.d // z22: p23 = c2 |
| 424 | fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3 |
| 425 | mov z24.d, z3.d // z24: c4 |
| 426 | fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5 |
| 427 | fmul z12.s, z12.s, z12.s // z12: r2 = r * r |
| 428 | fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 |
| 429 | fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 |
| 430 | fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale |
| 431 | dup z10.s, #0 // z10: 0 |
| 432 | sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly |
| 433 | |
| 434 | st1w z16.s, p1, [x28, x9, LSL #2] |
| 435 | |
| 436 | fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly |
| 437 | |
| 438 | incw x9 |
| 439 | b regularize_leftover_start%= |
| 440 | regularize_leftover_end%=: |
| 441 | |
| 442 | // ================================================== |
| 443 | // Step 3: Normalize |
| 444 | // ================================================== |
| 445 | |
| 446 | // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value |
| 447 | fmov s29, #1.0 // 1.0f |
| 448 | faddv s28, p0, z28.s |
| 449 | fdiv s28, s29, s28 |
| 450 | dup z28.s, z28.s[0] |
| 451 | |
| 452 | // Loop for processing 4 vectors per iteration. |
| 453 | mov x9, #0 // x9: index |
| 454 | |
| 455 | normalize_body_start%=: |
| 456 | cmp x9, x13 |
| 457 | b.eq normalize_body_end%= |
| 458 | |
| 459 | .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x |
| 460 | |
| 461 | // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value |
| 462 | fmul z12.s, z12.s, z28.s |
| 463 | fmul z13.s, z13.s, z28.s |
| 464 | fmul z14.s, z14.s, z28.s |
| 465 | fmul z15.s, z15.s, z28.s |
| 466 | |
| 467 | .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2] |
| 468 | |
| 469 | incw x9, ALL, MUL #4 |
| 470 | b normalize_body_start%= |
| 471 | normalize_body_end%=: |
| 472 | |
| 473 | // Loop for processing the leftover part. |
| 474 | normalize_leftover_start%=: |
| 475 | whilelo p1.s, x9, %x[length] |
| 476 | b.none normalize_leftover_end%= |
| 477 | |
| 478 | ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x |
| 479 | fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value |
| 480 | |
| 481 | st1w z12.s, p1, [x28, x9, LSL #2] |
| 482 | |
| 483 | incw x9 |
| 484 | b normalize_leftover_start%= |
| 485 | normalize_leftover_end%=: |
| 486 | |
| 487 | // ================================================== |
| 488 | // 3D loop closing |
| 489 | // ================================================== |
| 490 | |
| 491 | add x27, x27, %x[src_stride_1] |
| 492 | add x28, x28, %x[dst_stride_1] |
| 493 | b loop_1_start%= |
| 494 | loop_1_end%=: |
| 495 | |
| 496 | add x24, x24, %x[src_stride_2] |
| 497 | add x25, x25, %x[dst_stride_2] |
| 498 | b loop_2_start%= |
| 499 | loop_2_end%=: |
| 500 | |
| 501 | add x21, x21, %x[src_stride_3] |
| 502 | add x22, x22, %x[dst_stride_3] |
| 503 | b loop_3_start%= |
| 504 | loop_3_end%=: |
| 505 | |
| 506 | .inst 0xd503467f // smstop |
| 507 | )" |
| 508 | : |
| 509 | : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // |
| 510 | [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // |
| 511 | [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), |
| 512 | [src_stride_3] "r"(src_strides[3]), // |
| 513 | [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), |
| 514 | [dst_stride_3] "r"(dst_strides[3]), // |
| 515 | [length] "r"(shape[0]) // |
| 516 | : "cc", "memory", // |
| 517 | "p0", "p4", "p5", "p6", "p7", "p9", // |
| 518 | "x9", "x10", "x11", "x12", "x13", // |
| 519 | "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // |
| 520 | "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // |
| 521 | "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // |
| 522 | "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // |
| 523 | "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // |
| 524 | ); |
| 525 | } |
| 526 | |
| 527 | void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window) |
| 528 | { |
| 529 | ARM_COMPUTE_UNUSED(axis); |
| 530 | |
| 531 | const auto *src_info = in->info(); |
| 532 | const auto *dst_info = out->info(); |
| 533 | |
| 534 | const auto &full_shape = dst_info->tensor_shape(); |
| 535 | const auto &src_strides = src_info->strides_in_bytes(); |
| 536 | const auto &dst_strides = dst_info->strides_in_bytes(); |
| 537 | |
| 538 | const uintptr_t k_shape[] = { |
| 539 | full_shape[0], |
| 540 | window.num_iterations(1), |
| 541 | window.num_iterations(2), |
| 542 | window.num_iterations(3), |
| 543 | }; |
| 544 | |
| 545 | const uintptr_t k_src_strides[] = { |
| 546 | src_strides[0], |
| 547 | src_strides[1], |
| 548 | src_strides[2], |
| 549 | src_strides[3], |
| 550 | }; |
| 551 | |
| 552 | const uintptr_t k_dst_strides[] = { |
| 553 | dst_strides[0], |
| 554 | dst_strides[1], |
| 555 | dst_strides[2], |
| 556 | dst_strides[3], |
| 557 | }; |
| 558 | |
| 559 | const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // |
| 560 | window[1].start() * src_strides[1] + // |
| 561 | window[2].start() * src_strides[2] + // |
| 562 | window[3].start() * src_strides[3]; |
| 563 | |
| 564 | const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // |
| 565 | window[1].start() * dst_strides[1] + // |
| 566 | window[2].start() * dst_strides[2] + // |
| 567 | window[3].start() * dst_strides[3]; |
| 568 | |
| 569 | const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset); |
| 570 | auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset); |
| 571 | |
| 572 | sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); |
| 573 | } |
| 574 | |
| 575 | } // namespace cpu |
| 576 | } // namespace arm_compute |
| 577 | |
| 578 | #endif // ARM_COMPUTE_ENABLE_SME2 |