/*
* Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifdef ARM_COMPUTE_ENABLE_SME2
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Window.h"
namespace arm_compute
{
namespace cpu
{
// SoftMax
//
// Steps:
// * Find max: max_value = max(src)
// * Regularize: dst[i] = exp((src[i] - max_value) * beta)
//               sum_value = sum(dst)
// * Normalize: dst[i] = dst[i] / sum_value
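// For reference, a scalar sketch of the per-row computation (illustrative only; the
// helper name, the pointers and `length` are assumptions standing in for the innermost
// dimension handled by the kernel below; uses <algorithm>, <cmath> and <limits>):
//
//   void reference_softmax_row(const float *src, float *dst, float beta, size_t length)
//   {
//       // Step 1: find the row maximum for numerical stability.
//       float max_value = -std::numeric_limits<float>::infinity();
//       for (size_t i = 0; i < length; ++i)
//       {
//           max_value = std::max(max_value, src[i]);
//       }
//       // Step 2: exponentiate the shifted, scaled inputs and accumulate their sum.
//       float sum_value = 0.0f;
//       for (size_t i = 0; i < length; ++i)
//       {
//           dst[i] = std::exp((src[i] - max_value) * beta);
//           sum_value += dst[i];
//       }
//       // Step 3: normalize so the row sums to 1.
//       for (size_t i = 0; i < length; ++i)
//       {
//           dst[i] /= sum_value;
//       }
//   }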
void sme2_f32_softmax_kernel( //
const float *src,
float *dst,
float beta,
const uintptr_t shape[4],
const uintptr_t src_strides[4],
const uintptr_t dst_strides[4])
{
// Precondition:
// * src_strides[0] == sizeof(float)
// * dst_strides[0] == sizeof(float)
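// Because the element stride on dimension 0 is exactly sizeof(float), element i of the
// current row lives at base + i * sizeof(float); this is what the `[x27, x9, LSL #2]`
// and `[x28, x9, LSL #2]` addressing inside the kernel relies on.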
__asm__ volatile(
R"(
.inst 0xd503477f // smstart
// Registers
//
// * x9: temporary, index
// * x10: temporary, -inf
// * x11: temporary, 0
// * x12: temporary
// * x13: temporary, body_length
//
// * x20: index_3
// * x21: src_3
// * x22: dst_3
// * x23: index_2
// * x24: src_2
// * x25: dst_2
// * x26: index_1
// * x27: src_1
// * x28: dst_1
//
// * z0: c1
// * z1: c2
// * z2: c3
// * z3: c4
// * z4: c5
// * z5: shift
// * z6: inv_ln2
// * z7: neg_ln2_hi
// * z8: neg_ln2_lo
// * z9: min_input
// * z10: 23, 0
// * z11: max_value
// * z12-z15: x, r_hi, r, r2
// * z16-z19: max_value, shift, z, scale, poly
// * z20-z21: n, p1, p12345
// * z22-z23: n, p23, p2345
// * z24-z25: p45
// * z26: beta
// * z28-z31: sum_value
//
// * za0-za3: sum_value
//
// * p0: all-true
// * p1: left-over predicate
// * p4-p7: underflow
// * pn9: all-true
// Prepares all constant values
ptrue p0.b
.inst 0x25207811 // ptrue pn9.b
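// Each 32-bit float constant is materialized in two halves: `mov` writes bits 0-15
// and `movk ..., LSL #16` writes bits 16-31 without disturbing the low half. For
// example, c1 below is built as 0x0000fff6 and then 0x3f7ffff6, the IEEE-754
// encoding of 0x1.ffffecp-1f.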
mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
movk w11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
dup z0.s, w9 // c1.
dup z1.s, w10 // c2.
dup z2.s, w11 // c3.
dup z3.s, w12 // c4.
dup z4.s, w13 // c5.
mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
dup z5.s, w9 // shift
dup z6.s, w10 // inv_ln2
dup z7.s, w11 // neg_ln2_hi
dup z8.s, w12 // neg_ln2_lo
dup z9.s, w13 // min_input
dup z26.s, %w[beta] // beta
mov w10, #0x0000 // -inf: 0xff800000
movk w10, #0xff80, LSL #16 // -inf: 0xff800000
mov w11, #0 // 0
// ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
cntw x13, ALL, MUL #4
udiv x9, %x[length], x13
mul x13, x13, x9
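// Worked example (assuming a 512-bit vector length): cntw ALL, MUL #4 gives
// 4 * 16 = 64 elements per unrolled iteration, so for length = 100 the body
// loops cover body_length = (100 / 64) * 64 = 64 elements and the remaining
// 36 elements are handled by the predicated leftover loops.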
// ==================================================
// 3D loop opening
// ==================================================
mov x20, %x[shape_3]
mov x21, %x[src]
mov x22, %x[dst]
loop_3_start%=:
// for index_3 in shape_3 downto 1
cmp x20, #0
b.eq loop_3_end%=
sub x20, x20, #1
mov x23, %x[shape_2]
mov x24, x21
mov x25, x22
loop_2_start%=:
// for index_2 in shape_2 downto 1
cmp x23, #0
b.eq loop_2_end%=
sub x23, x23, #1
mov x26, %x[shape_1]
mov x27, x24
mov x28, x25
loop_1_start%=:
// for index_1 in shape_1 downto 1
cmp x26, #0
b.eq loop_1_end%=
sub x26, x26, #1
// ==================================================
// Step 1: Find max
// ==================================================
// Loop for processing 4 vectors per iteration.
mov x9, #0 // x9: index
dup z11.s, w10 // z11: max_value = -inf
// ---------------------------------------------------------------- z16-z19: max_value = -inf
mov z16.d, z11.d
mov z17.d, z11.d
mov z18.d, z11.d
mov z19.d, z11.d
find_max_body_start%=:
cmp x9, x13
b.eq find_max_body_end%=
.inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x
.inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x)
incw x9, ALL, MUL #4
b find_max_body_start%=
find_max_body_end%=:
// Loop for processing the leftover part.
find_max_leftover_start%=:
whilelo p1.s, x9, %x[length]
b.none find_max_leftover_end%=
ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x
fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x)
incw x9
b find_max_leftover_start%=
find_max_leftover_end%=:
// ---------------------------------------------------------------- z16: max_value
.inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s}
fmax z16.s, p0/m, z16.s, z17.s
fmaxv s16, p0, z16.s
// ---------------------------------------------------------------- z11: max_value
dup z11.s, z16.s[0]
// ==================================================
// Step 2: Regularize
// ==================================================
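// The exponential below uses the usual range reduction
//   exp(x) = 2^n * exp(r),  n = round(x / ln(2)),  r = x - n * ln(2)
// where n is obtained by adding x * inv_ln2 to the shift constant (2^23 + 127) so
// that the rounding happens in the float mantissa, r is reconstructed with the
// split constants neg_ln2_hi/neg_ln2_lo for extra precision, and exp(r) is
// approximated by the degree-5 polynomial
//   exp(r) ~= 1 + c1*r + c2*r^2 + c3*r^3 + c4*r^4 + c5*r^5
// evaluated as poly = scale + p12345 * scale with scale = 2^n. Inputs below
// min_input would underflow, so their results are forced to zero via p4-p7.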
.inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value
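// w11 still holds 0 from the constant setup above; it is used as the ZA slice index
// for the VGx4 fadd below, so za0-za3 act as four independent running sums that are
// collapsed into z28 once the body loop finishes.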
// Loop for processing 4 vectors per iteration.
mov x9, #0 // ---------------------------------------------------- x9: index
regularize_body_start%=:
cmp x9, x13
b.eq regularize_body_end%=
// Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
.inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
// ---------------------------------------------------------------- z12-z15: x = input_data - max_value
fsub z12.s, z12.s, z11.s
fsub z13.s, z13.s, z11.s
fsub z14.s, z14.s, z11.s
fsub z15.s, z15.s, z11.s
// ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
fmul z12.s, z12.s, z26.s
fmul z13.s, z13.s, z26.s
fmul z14.s, z14.s, z26.s
fmul z15.s, z15.s, z26.s
// ---------------------------------------------------------------- z16-z19: shift
mov z16.d, z5.d
mov z17.d, z5.d
mov z18.d, z5.d
mov z19.d, z5.d
// ---------------------------------------------------------------- p4-p7: underflow = x < min_input
fcmlt p4.s, p0/z, z12.s, z9.s
fcmlt p5.s, p0/z, z13.s, z9.s
fcmlt p6.s, p0/z, z14.s, z9.s
fcmlt p7.s, p0/z, z15.s, z9.s
// ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
fmla z16.s, p0/m, z12.s, z6.s
fmla z17.s, p0/m, z13.s, z6.s
fmla z18.s, p0/m, z14.s, z6.s
fmla z19.s, p0/m, z15.s, z6.s
// ---------------------------------------------------------------- z20-z23: n = z - shift
fsub z20.s, z16.s, z5.s
fsub z21.s, z17.s, z5.s
fsub z22.s, z18.s, z5.s
fsub z23.s, z19.s, z5.s
// ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
fmla z12.s, p0/m, z20.s, z7.s
fmla z13.s, p0/m, z21.s, z7.s
fmla z14.s, p0/m, z22.s, z7.s
fmla z15.s, p0/m, z23.s, z7.s
// ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
fmla z12.s, p0/m, z20.s, z8.s
fmla z13.s, p0/m, z21.s, z8.s
fmla z14.s, p0/m, z22.s, z8.s
fmla z15.s, p0/m, z23.s, z8.s
// ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
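// z is shift + n in floating point, i.e. 2^23 + (127 + n), whose bit pattern is
// 0x4b000000 + (127 + n). Shifting that pattern left by 23 bits drops the
// 0x4b000000 part and leaves (127 + n) << 23, which is exactly the IEEE-754
// encoding of 2^n.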
dup z10.s, #23
urshl z16.s, p0/m, z16.s, z10.s
urshl z17.s, p0/m, z17.s, z10.s
urshl z18.s, p0/m, z18.s, z10.s
urshl z19.s, p0/m, z19.s, z10.s
// Processes the first 2 vectors.
// ---------------------------------------------------------------- z20-z21: p1 = r * c1
fmul z20.s, z12.s, z0.s
fmul z21.s, z13.s, z0.s
// ---------------------------------------------------------------- z22-z23: p23 = c2
mov z22.d, z1.d
mov z23.d, z1.d
// ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
fmla z22.s, p0/m, z12.s, z2.s
fmla z23.s, p0/m, z13.s, z2.s
// ---------------------------------------------------------------- z24-z25: c4
mov z24.d, z3.d
mov z25.d, z3.d
// ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
fmla z24.s, p0/m, z12.s, z4.s
fmla z25.s, p0/m, z13.s, z4.s
// ---------------------------------------------------------------- z12-z13: r2 = r * r
fmul z12.s, z12.s, z12.s
fmul z13.s, z13.s, z13.s
// ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
fmla z22.s, p0/m, z12.s, z24.s
fmla z23.s, p0/m, z13.s, z25.s
// ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
fmla z20.s, p0/m, z12.s, z22.s
fmla z21.s, p0/m, z13.s, z23.s
// ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
fmla z16.s, p0/m, z20.s, z16.s
fmla z17.s, p0/m, z21.s, z17.s
// Processes the last 2 vectors
// ---------------------------------------------------------------- z20-z21: p1 = r * c1
fmul z20.s, z14.s, z0.s
fmul z21.s, z15.s, z0.s
// ---------------------------------------------------------------- z22-z23: p23 = c2
mov z22.d, z1.d
mov z23.d, z1.d
// ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
fmla z22.s, p0/m, z14.s, z2.s
fmla z23.s, p0/m, z15.s, z2.s
// ---------------------------------------------------------------- z24-z25: c4
mov z24.d, z3.d
mov z25.d, z3.d
// ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
fmla z24.s, p0/m, z14.s, z4.s
fmla z25.s, p0/m, z15.s, z4.s
// ---------------------------------------------------------------- z14-z15: r2 = r * r
fmul z14.s, z14.s, z14.s
fmul z15.s, z15.s, z15.s
// ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
fmla z22.s, p0/m, z14.s, z24.s
fmla z23.s, p0/m, z15.s, z25.s
// ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
fmla z20.s, p0/m, z14.s, z22.s
fmla z21.s, p0/m, z15.s, z23.s
// ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
fmla z18.s, p0/m, z20.s, z18.s
fmla z19.s, p0/m, z21.s, z19.s
// ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
dup z10.s, #0
sel z16.s, p4, z10.s, z16.s
sel z17.s, p5, z10.s, z17.s
sel z18.s, p6, z10.s, z18.s
sel z19.s, p7, z10.s, z19.s
// Stores 4 consecutive registers to the output
.inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
.inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly
incw x9, ALL, MUL #4
b regularize_body_start%=
regularize_body_end%=:
// ---------------------------------------------------------------- z28: sum_value
.inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
fadd z28.s, z28.s, z29.s
fadd z30.s, z30.s, z31.s
fadd z28.s, z28.s, z30.s
// Loop for processing the leftover part.
regularize_leftover_start%=:
whilelo p1.s, x9, %x[length]
b.none regularize_leftover_end%=
ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: input_data
fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value
fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta
mov z16.d, z5.d // z16: shift
fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
fsub z20.s, z16.s, z5.s // z20: n = z - shift
fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
dup z10.s, #23 // z10: 23
urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
mov z22.d, z1.d // z22: p23 = c2
fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
mov z24.d, z3.d // z24: c4
fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
fmul z12.s, z12.s, z12.s // z12: r2 = r * r
fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
dup z10.s, #0 // z10: 0
sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
st1w z16.s, p1, [x28, x9, LSL #2]
fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly
incw x9
b regularize_leftover_start%=
regularize_leftover_end%=:
// ==================================================
// Step 3: Normalize
// ==================================================
// ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
fmov s29, #1.0 // 1.0f
faddv s28, p0, z28.s
fdiv s28, s29, s28
dup z28.s, z28.s[0]
// Loop for processing 4 vectors per iteration.
mov x9, #0 // x9: index
normalize_body_start%=:
cmp x9, x13
b.eq normalize_body_end%=
.inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x
// ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
fmul z12.s, z12.s, z28.s
fmul z13.s, z13.s, z28.s
fmul z14.s, z14.s, z28.s
fmul z15.s, z15.s, z28.s
.inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2]
incw x9, ALL, MUL #4
b normalize_body_start%=
normalize_body_end%=:
// Loop for processing the leftover part.
normalize_leftover_start%=:
whilelo p1.s, x9, %x[length]
b.none normalize_leftover_end%=
ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x
fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value
st1w z12.s, p1, [x28, x9, LSL #2]
incw x9
b normalize_leftover_start%=
normalize_leftover_end%=:
// ==================================================
// 3D loop closing
// ==================================================
add x27, x27, %x[src_stride_1]
add x28, x28, %x[dst_stride_1]
b loop_1_start%=
loop_1_end%=:
add x24, x24, %x[src_stride_2]
add x25, x25, %x[dst_stride_2]
b loop_2_start%=
loop_2_end%=:
add x21, x21, %x[src_stride_3]
add x22, x22, %x[dst_stride_3]
b loop_3_start%=
loop_3_end%=:
.inst 0xd503467f // smstop
)"
:
: [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
[shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
[src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
[src_stride_3] "r"(src_strides[3]), //
[dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
[dst_stride_3] "r"(dst_strides[3]), //
[length] "r"(shape[0]) //
: "cc", "memory", //
"p0", "p4", "p5", "p6", "p7", "p9", //
"x9", "x10", "x11", "x12", "x13", //
"x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
"z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
"z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
"z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
"z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
);
}
void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
{
ARM_COMPUTE_UNUSED(axis);
const auto *src_info = in->info();
const auto *dst_info = out->info();
const auto &full_shape = dst_info->tensor_shape();
const auto &src_strides = src_info->strides_in_bytes();
const auto &dst_strides = dst_info->strides_in_bytes();
const uintptr_t k_shape[] = {
full_shape[0],
window.num_iterations(1),
window.num_iterations(2),
window.num_iterations(3),
};
const uintptr_t k_src_strides[] = {
src_strides[0],
src_strides[1],
src_strides[2],
src_strides[3],
};
const uintptr_t k_dst_strides[] = {
dst_strides[0],
dst_strides[1],
dst_strides[2],
dst_strides[3],
};
const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
window[1].start() * src_strides[1] + //
window[2].start() * src_strides[2] + //
window[3].start() * src_strides[3];
const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
window[1].start() * dst_strides[1] + //
window[2].start() * dst_strides[2] + //
window[3].start() * dst_strides[3];
const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset);
auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset);
sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
}
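// A minimal illustrative call (an assumption for documentation purposes, not taken from
// library code): given already-allocated fp32 ITensors `src` and `dst` of identical shape,
// with softmax taken along dimension 0,
//
//   Window win;
//   win.use_tensor_dimensions(dst->info()->tensor_shape());
//   sme2_fp32_softmax(src, nullptr, dst, /* beta */ 1.0f, /* axis */ 0, win);
//
// In practice this entry point is normally reached through the library's CPU softmax
// kernel dispatch rather than being called directly.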
} // namespace cpu
} // namespace arm_compute
#endif // ARM_COMPUTE_ENABLE_SME2