/*
2 * Copyright (c) 2023-2024 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25#ifdef ARM_COMPUTE_ENABLE_SME2
26
27#include "arm_compute/core/ITensor.h"
28#include "arm_compute/core/Window.h"
29
30namespace arm_compute
31{
32namespace cpu
33{
34
35// SoftMax
36//
37// Steps:
38// * Find max: max_value = max(src)
39// * Regularize: dst[i] = exp(src[i] - max_value)
40// sum_value = sum(dst)
41// * Normalize: dst[i] = dst[i] / sum_value
42void sme2_f32_softmax_kernel( //
43 const float *src,
44 float *dst,
45 float beta,
46 const uintptr_t shape[4],
47 const uintptr_t src_strides[4],
48 const uintptr_t dst_strides[4])
49{
50 // Precondition:
51 // * src_strides[0] == sizeof(float)
52 // * dst_strides[0] == sizeof(float)
53
54 __asm__ volatile(
55 R"(
56 .inst 0xd503477f // smstart
57
58 // Registers
59 //
60 // * x9: temporary, index
61 // * x10: temporary, -inf
62 // * x11: temporary, 0
63 // * x12: temporary, 1.0f
64 // * x13: temporary, body_length
65 //
66 // * x20: index_3
67 // * x21: src_3
68 // * x22: dst_3
69 // * x23: index_2
70 // * x24: src_2
71 // * x25: dst_2
72 // * x26: index_1
73 // * x27: src_1
74 // * x28: dst_1
75 //
76 // * z0: c1
77 // * z1: c2
78 // * z2: c3
79 // * z3: c4
80 // * z4: c5
81 // * z5: shift
82 // * z6: inv_ln2
83 // * z7: neg_ln2_hi
84 // * z8: neg_ln2_lo
85 // * z9: min_input
86 // * z10: 23, 0
87 // * z11: max_value
88 // * z12-z15: x, r_hi, r, r2
89 // * z16-z19: max_value, shift, z, scale, poly
90 // * z20-z21: n, p1, p12345
91 // * z22-z23: n, p23, p2345
92 // * z24-z25: p45
93 // * z26: beta
94 // * z28-z31: sum_value
95 //
96 // * za0-za3: sum_value
97 //
98 // * p0: all-true
99 // * p1: left-over predicate
100 // * p4-p7: underflow
101 // * pn9: all-true
102
103 // Prepares all constant values
104
105 ptrue p0.b
106 .inst 0x25207811 // ptrue pn9.b
107
108 mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
109 mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
110 mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
111 mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
112 mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
113
114 movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
115 movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
116 movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
117 movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
118 movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
119
120 dup z0.s, w9 // c1.
121 dup z1.s, w10 // c2.
122 dup z2.s, w11 // c3.
123 dup z3.s, w12 // c4.
124 dup z4.s, w13 // c5.
125
126 mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
127 mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
128 mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
129 mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
130 mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
131
132 movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
133 movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
134 movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
135 movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
136 movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
137
138 dup z5.s, w9 // shift
139 dup z6.s, w10 // inv_ln2
140 dup z7.s, w11 // neg_ln2_hi
141 dup z8.s, w12 // neg_ln2_lo
142 dup z9.s, w13 // min_input
143
144 dup z26.s, %w[beta] // beta
145
146 mov w10, #0x0000 // -inf: 0xff800000
147 movk w10, #0xff80 // -inf: 0xff800000
148
149 mov w11, #0 // 0
150
151 // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
152 cntw x13, ALL, MUL #4
153 udiv x9, %x[length], x13
154 mul x13, x13, x9
155
156 // ==================================================
157 // 3D loop opening
158 // ==================================================
159
160 mov x20, %x[shape_3]
161 mov x21, %x[src]
162 mov x22, %x[dst]
163
164loop_3_start%=:
165 // for index_3 in shape_3 downto 1
166 cmp x20, #0
167 b.eq loop_3_end%=
168 sub x20, x20, #1
169
170 mov x23, %x[shape_2]
171 mov x24, x21
172 mov x25, x22
173
174loop_2_start%=:
175 // for index_2 in shape_2 downto 1
176 cmp x23, #0
177 b.eq loop_2_end%=
178 sub x23, x23, #1
179
180 mov x26, %x[shape_1]
181 mov x27, x24
182 mov x28, x25
183
184loop_1_start%=:
185 // for index_1 in shape_2 downto 1
186 cmp x26, #0
187 b.eq loop_1_end%=
188 sub x26, x26, #1
189
190 // ==================================================
191 // Step 1: Find max
192 // ==================================================
193
194 // ---------------------------------------------------------------- z16-z19: max_value = -inf
195 mov z16.d, z11.d
196 mov z17.d, z11.d
197 mov z18.d, z11.d
198 mov z19.d, z11.d
199
200 // Loop for processing 4 vectors per iteration.
201 mov x9, #0 // x9: index
202 dup z11.s, w10 // z11: max_value = -inf
203
204find_max_body_start%=:
205 cmp x9, x13
206 b.eq find_max_body_end%=
207
208 .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x
209 .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x)
210
211 incw x9, ALL, MUL #4
212 b find_max_body_start%=
213find_max_body_end%=:
214
215 // Loop for processing the leftover part.
216find_max_leftover_start%=:
217 whilelo p1.s, x9, %x[length]
218 b.none find_max_leftover_end%=
219
220 ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x
221 fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x)
222
223 incw x9
224 b find_max_leftover_start%=
225find_max_leftover_end%=:
226
227 // ---------------------------------------------------------------- z16: max_value
228 .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s}
229 fmax z16.s, p0/m, z16.s, z17.s
230 fmaxv s16, p0, z16.s
231
232 // ---------------------------------------------------------------- z11: max_value
233 dup z11.s, z16.s[0]
234
235 // ==================================================
236 // Step 2: Regularize
237 // ==================================================
238
239 .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value
240
241 // Loop for processing 4 vectors per iteration.
242 mov x9, #0 // ---------------------------------------------------- x9: index
243
244regularize_body_start%=:
245 cmp x9, x13
246 b.eq regularize_body_end%=
247
248 // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
249 .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
250
251 // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
252 fsub z12.s, z12.s, z11.s
253 fsub z13.s, z13.s, z11.s
254 fsub z14.s, z14.s, z11.s
255 fsub z15.s, z15.s, z11.s
256
257 // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
258 fmul z12.s, z12.s, z26.s
259 fmul z13.s, z13.s, z26.s
260 fmul z14.s, z14.s, z26.s
261 fmul z15.s, z15.s, z26.s
262
263 // ---------------------------------------------------------------- z16-z19: shift
264 mov z16.d, z5.d
265 mov z17.d, z5.d
266 mov z18.d, z5.d
267 mov z19.d, z5.d
268
269 // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
270 fcmlt p4.s, p0/z, z12.s, z9.s
271 fcmlt p5.s, p0/z, z13.s, z9.s
272 fcmlt p6.s, p0/z, z14.s, z9.s
273 fcmlt p7.s, p0/z, z15.s, z9.s
274
275 // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
276 fmla z16.s, p0/m, z12.s, z6.s
277 fmla z17.s, p0/m, z13.s, z6.s
278 fmla z18.s, p0/m, z14.s, z6.s
279 fmla z19.s, p0/m, z15.s, z6.s
280
281 // ---------------------------------------------------------------- z20-z23: n = z - shift
282 fsub z20.s, z16.s, z5.s
283 fsub z21.s, z17.s, z5.s
284 fsub z22.s, z18.s, z5.s
285 fsub z23.s, z19.s, z5.s
286
287 // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
288 fmla z12.s, p0/m, z20.s, z7.s
289 fmla z13.s, p0/m, z21.s, z7.s
290 fmla z14.s, p0/m, z22.s, z7.s
291 fmla z15.s, p0/m, z23.s, z7.s
292
293 // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
294 fmla z12.s, p0/m, z20.s, z8.s
295 fmla z13.s, p0/m, z21.s, z8.s
296 fmla z14.s, p0/m, z22.s, z8.s
297 fmla z15.s, p0/m, z23.s, z8.s
298
299 // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
300 dup z10.s, #23
301 urshl z16.s, p0/m, z16.s, z10.s
302 urshl z17.s, p0/m, z17.s, z10.s
303 urshl z18.s, p0/m, z18.s, z10.s
304 urshl z19.s, p0/m, z19.s, z10.s
305
306 // Processes the first 2 vectors.
307
308 // ---------------------------------------------------------------- z20-z21: p1 = r * c1
309 fmul z20.s, z12.s, z0.s
310 fmul z21.s, z13.s, z0.s
311
312 // ---------------------------------------------------------------- z22-z23: p23 = c2
313 mov z22.d, z1.d
314 mov z23.d, z1.d
315
316 // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
317 fmla z22.s, p0/m, z12.s, z2.s
318 fmla z23.s, p0/m, z13.s, z2.s
319
320 // ---------------------------------------------------------------- z24-z35: c4
321 mov z24.d, z3.d
322 mov z25.d, z3.d
323
324 // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
325 fmla z24.s, p0/m, z12.s, z4.s
326 fmla z25.s, p0/m, z13.s, z4.s
327
328 // ---------------------------------------------------------------- z12-z13: r2 = r * r
329 fmul z12.s, z12.s, z12.s
330 fmul z13.s, z13.s, z13.s
331
332 // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
333 fmla z22.s, p0/m, z12.s, z24.s
334 fmla z23.s, p0/m, z13.s, z25.s
335
336 // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
337 fmla z20.s, p0/m, z12.s, z22.s
338 fmla z21.s, p0/m, z13.s, z23.s
339
340 // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
341 fmla z16.s, p0/m, z20.s, z16.s
342 fmla z17.s, p0/m, z21.s, z17.s
343
344 // Processes the last 2 vectors
345
346 // ---------------------------------------------------------------- z20-z21: p1 = r * c1
347 fmul z20.s, z14.s, z0.s
348 fmul z21.s, z15.s, z0.s
349
350 // ---------------------------------------------------------------- z22-z23: p23 = c2
351 mov z22.d, z1.d
352 mov z23.d, z1.d
353
354 // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
355 fmla z22.s, p0/m, z14.s, z2.s
356 fmla z23.s, p0/m, z15.s, z2.s
357
358 // ---------------------------------------------------------------- z24-z35: c4
359 mov z24.d, z3.d
360 mov z25.d, z3.d
361
362 // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
363 fmla z24.s, p0/m, z14.s, z4.s
364 fmla z25.s, p0/m, z15.s, z4.s
365
366 // ---------------------------------------------------------------- z14-z15: r2 = r * r
367 fmul z14.s, z14.s, z14.s
368 fmul z15.s, z15.s, z15.s
369
370 // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
371 fmla z22.s, p0/m, z14.s, z24.s
372 fmla z23.s, p0/m, z15.s, z25.s
373
374 // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
375 fmla z20.s, p0/m, z14.s, z22.s
376 fmla z21.s, p0/m, z15.s, z23.s
377
378 // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
379 fmla z18.s, p0/m, z20.s, z18.s
380 fmla z19.s, p0/m, z21.s, z19.s
381
382 // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
383 dup z10.s, #0
384 sel z16.s, p4, z10.s, z16.s
385 sel z17.s, p5, z10.s, z17.s
386 sel z18.s, p6, z10.s, z18.s
387 sel z19.s, p7, z10.s, z19.s
388
389 // Stores 4 consecutive registers to the output
390 .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
391
392 .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly
393
394 incw x9, ALL, MUL #4
395 b regularize_body_start%=
396regularize_body_end%=:
397
398 // ---------------------------------------------------------------- z28: sum_value
399 .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
400 fadd z28.s, z28.s, z29.s
401 fadd z30.s, z30.s, z31.s
402 fadd z28.s, z28.s, z30.s
403
404 // Loop for processing the leftover part.
405regularize_leftover_start%=:
406 whilelo p1.s, x9, %x[length]
407 b.none regularize_leftover_end%=
408
409 ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data
410
411 fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value
412 fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta
413
414 mov z16.d, z5.d // z16: shift
415 fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
416 fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
417 fsub z20.s, z16.s, z5.s // z20: n = z - shift
418 fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
419 fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
420 dup z10.s, #23 // z10: 23
421 urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
422 fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
423 mov z22.d, z1.d // z22: p23 = c2
424 fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
425 mov z24.d, z3.d // z24: c4
426 fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
427 fmul z12.s, z12.s, z12.s // z12: r2 = r * r
428 fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
429 fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
430 fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
431 dup z10.s, #0 // z10: 0
432 sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
433
434 st1w z16.s, p1, [x28, x9, LSL #2]
435
436 fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly
437
438 incw x9
439 b regularize_leftover_start%=
440regularize_leftover_end%=:
441
442 // ==================================================
443 // Step 3: Normalize
444 // ==================================================
445
446 // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
447 fmov s29, #1.0 // 1.0f
448 faddv s28, p0, z28.s
449 fdiv s28, s29, s28
450 dup z28.s, z28.s[0]
451
452 // Loop for processing 4 vectors per iteration.
453 mov x9, #0 // x9: index
454
455normalize_body_start%=:
456 cmp x9, x13
457 b.eq normalize_body_end%=
458
459 .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x
460
461 // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
462 fmul z12.s, z12.s, z28.s
463 fmul z13.s, z13.s, z28.s
464 fmul z14.s, z14.s, z28.s
465 fmul z15.s, z15.s, z28.s
466
467 .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2]
468
469 incw x9, ALL, MUL #4
470 b normalize_body_start%=
471normalize_body_end%=:
472
473 // Loop for processing the leftover part.
474normalize_leftover_start%=:
475 whilelo p1.s, x9, %x[length]
476 b.none normalize_leftover_end%=
477
478 ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x
479 fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value
480
481 st1w z12.s, p1, [x28, x9, LSL #2]
482
483 incw x9
484 b normalize_leftover_start%=
485normalize_leftover_end%=:
486
487 // ==================================================
488 // 3D loop closing
489 // ==================================================
490
491 add x27, x27, %x[src_stride_1]
492 add x28, x28, %x[dst_stride_1]
493 b loop_1_start%=
494loop_1_end%=:
495
496 add x24, x24, %x[src_stride_2]
497 add x25, x25, %x[dst_stride_2]
498 b loop_2_start%=
499loop_2_end%=:
500
501 add x21, x21, %x[src_stride_3]
502 add x22, x22, %x[dst_stride_3]
503 b loop_3_start%=
504loop_3_end%=:
505
506 .inst 0xd503467f // smstop
507 )"
508 :
509 : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
510 [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
511 [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
512 [src_stride_3] "r"(src_strides[3]), //
513 [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
514 [dst_stride_3] "r"(dst_strides[3]), //
515 [length] "r"(shape[0]) //
516 : "cc", "memory", //
517 "p0", "p4", "p5", "p6", "p7", "p9", //
518 "x9", "x10", "x11", "x12", "x13", //
519 "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
520 "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
521 "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
522 "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
523 "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
524 );
525}
526
527void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
528{
529 ARM_COMPUTE_UNUSED(axis);
530
531 const auto *src_info = in->info();
532 const auto *dst_info = out->info();
533
534 const auto &full_shape = dst_info->tensor_shape();
535 const auto &src_strides = src_info->strides_in_bytes();
536 const auto &dst_strides = dst_info->strides_in_bytes();
537
538 const uintptr_t k_shape[] = {
539 full_shape[0],
540 window.num_iterations(1),
541 window.num_iterations(2),
542 window.num_iterations(3),
543 };
544
545 const uintptr_t k_src_strides[] = {
546 src_strides[0],
547 src_strides[1],
548 src_strides[2],
549 src_strides[3],
550 };
551
552 const uintptr_t k_dst_strides[] = {
553 dst_strides[0],
554 dst_strides[1],
555 dst_strides[2],
556 dst_strides[3],
557 };
558
559 const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
560 window[1].start() * src_strides[1] + //
561 window[2].start() * src_strides[2] + //
562 window[3].start() * src_strides[3];
563
564 const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
565 window[1].start() * dst_strides[1] + //
566 window[2].start() * dst_strides[2] + //
567 window[3].start() * dst_strides[3];
568
569 const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset);
570 auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset);
571
572 sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
573}
574
575} // namespace cpu
576} // namespace arm_compute
577
578#endif // ARM_COMPUTE_ENABLE_SME2