COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels

Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_s32_4x4.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_s32_4x4.hpp
new file mode 100644
index 0000000..a93060f
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_s32_4x4.hpp
@@ -0,0 +1,433 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef __aarch64__
+
+template<>
+void MergeResults<4, 4, false>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t *bias, Activation act, bool append)
+{
+    UNUSED(act);
+
+    const int32_t *inptr = in;
+    int32_t nullbias[4] = { 0 };
+
+
+    if (!append && !bias)
+    {
+        memset(nullbias, 0, (4 * sizeof(int32_t)));
+    }
+
+    for (int y=y0; y<ymax; y+=4)
+    {
+        int32_t *outptr0 = out + (y * ldout) + x0;
+        int32_t *outptr1 = outptr0 + ldout;
+        int32_t *outptr2 = outptr1 + ldout;
+        int32_t *outptr3 = outptr2 + ldout;
+
+        const int height = ymax - y;
+
+        for (int i=x0; i<xmax; i+=4)
+        {
+            if (append)
+            {
+                switch(height)
+                {
+                case 1:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 += inptr[xi];
+                                    outptr0++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[outptr0]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q10, [%[inptr]]\n"
+                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v10.4s, v10.4s, v2.4s\n"
+                                "str q10, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            :
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                case 2:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 += inptr[xi];
+                                    outptr0++;
+                                    *outptr1 += inptr[xi + 4];
+                                    outptr1++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[outptr0]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q10, [%[inptr]]\n"
+                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q3, [%[outptr1]]\n"
+                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v10.4s, v10.4s, v2.4s\n"
+                                "ldr q11, [%[inptr], #0x10]\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v11.4s, v11.4s, v3.4s\n"
+                                "str q10, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "str q11, [%[outptr1]]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            :
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                case 3:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 += inptr[xi];
+                                    outptr0++;
+                                    *outptr1 += inptr[xi + 4];
+                                    outptr1++;
+                                    *outptr2 += inptr[xi + 8];
+                                    outptr2++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[outptr0]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q10, [%[inptr]]\n"
+                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q3, [%[outptr1]]\n"
+                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v10.4s, v10.4s, v2.4s\n"
+                                "ldr q11, [%[inptr], #0x10]\n"
+                                "ldr q4, [%[outptr2]]\n"
+                                "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
+                                "ldr q12, [%[inptr], #0x20]\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v11.4s, v11.4s, v3.4s\n"
+                                "str q10, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "add v12.4s, v12.4s, v4.4s\n"
+                                "str q11, [%[outptr1]]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                                "str q12, [%[outptr2]]\n"
+                                "add %[outptr2], %[outptr2], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            :
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                default:
+                case 4:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 += inptr[xi];
+                                    outptr0++;
+                                    *outptr1 += inptr[xi + 4];
+                                    outptr1++;
+                                    *outptr2 += inptr[xi + 8];
+                                    outptr2++;
+                                    *outptr3 += inptr[xi + 12];
+                                    outptr3++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[outptr0]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q10, [%[inptr]]\n"
+                                "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q3, [%[outptr1]]\n"
+                                "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v10.4s, v10.4s, v2.4s\n"
+                                "ldr q11, [%[inptr], #0x10]\n"
+                                "ldr q4, [%[outptr2]]\n"
+                                "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
+                                "ldr q12, [%[inptr], #0x20]\n"
+                                "prfm PLDL1KEEP, [%[outptr3], #0x20]\n"
+                                "add v11.4s, v11.4s, v3.4s\n"
+                                "str q10, [%[outptr0]]\n"
+                                "ldr q5, [%[outptr3]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "add v12.4s, v12.4s, v4.4s\n"
+                                "str q11, [%[outptr1]]\n"
+                                "ldr q13, [%[inptr], #0x30]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "str q12, [%[outptr2]]\n"
+                                "add %[outptr2], %[outptr2], #0x10\n"
+                                "add v13.4s, v13.4s, v5.4s\n"
+                                "str q13, [%[outptr3]]\n"
+                                "add %[outptr3], %[outptr3], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            :
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+
+                }
+            }
+            else
+            {
+                const int32_t *biasptr = nullbias;
+                if (bias)
+                {
+                    biasptr = bias + i;
+                }
+
+                switch(height)
+                {
+                case 1:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 = biasptr[xi] + inptr[xi];
+                                    outptr0++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[biasptr]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q11, [%[inptr]]\n"
+                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v11.4s, v11.4s, v2.4s\n"
+                                "str q11, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            : [biasptr] "r" (biasptr)
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                case 2:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 = biasptr[xi] + inptr[xi];
+                                    outptr0++;
+                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
+                                    outptr1++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[biasptr]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q11, [%[inptr]]\n"
+                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q12, [%[inptr], #0x10]\n"
+                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v11.4s, v11.4s, v2.4s\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v12.4s, v12.4s, v2.4s\n"
+                                "str q11, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "str q12, [%[outptr1]]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            : [biasptr] "r" (biasptr)
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                case 3:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 = biasptr[xi] + inptr[xi];
+                                    outptr0++;
+                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
+                                    outptr1++;
+                                    *outptr2 = biasptr[xi] + inptr[xi + 8];
+                                    outptr2++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[biasptr]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q11, [%[inptr]]\n"
+                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q12, [%[inptr], #0x10]\n"
+                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v11.4s, v11.4s, v2.4s\n"
+                                "ldr q13, [%[inptr], #0x20]\n"
+                                "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
+                                "add v12.4s, v12.4s, v2.4s\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "add v13.4s, v13.4s, v2.4s\n"
+                                "str q11, [%[outptr0]]\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "str q12, [%[outptr1]]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                                "str q13, [%[outptr2]]\n"
+                                "add %[outptr2], %[outptr2], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            : [biasptr] "r" (biasptr)
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+                default:
+                case 4:
+                    {
+                        if ((i+3) >= xmax)
+                        {
+                            for (int xi=0; xi<3; xi++)
+                            {
+                                if ((i+xi) < xmax)
+                                {
+                                    *outptr0 = biasptr[xi] + inptr[xi];
+                                    outptr0++;
+                                    *outptr1 = biasptr[xi] + inptr[xi + 4];
+                                    outptr1++;
+                                    *outptr2 = biasptr[xi] + inptr[xi + 8];
+                                    outptr2++;
+                                    *outptr3 = biasptr[xi] + inptr[xi + 12];
+                                    outptr3++;
+                                }
+                            }
+                            inptr += 16;
+                        } else {
+                            /* Optimized routine to copy an entire block */
+                            __asm __volatile (
+                                "ldr q2, [%[biasptr]]\n"
+                                "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
+                                "ldr q11, [%[inptr]]\n"
+                                "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
+                                "ldr q12, [%[inptr], #0x10]\n"
+                                "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
+                                "add v11.4s, v11.4s, v2.4s\n"
+                                "ldr q13, [%[inptr], #0x20]\n"
+                                "ldr q14, [%[inptr], #0x30]\n"
+                                "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
+                                "add v12.4s, v12.4s, v2.4s\n"
+                                "str q11, [%[outptr0]]\n"
+                                "add v13.4s, v13.4s, v2.4s\n"
+                                "add %[outptr0], %[outptr0], #0x10\n"
+                                "add v14.4s, v14.4s, v2.4s\n"
+                                "str q12, [%[outptr1]]\n"
+                                "add %[outptr1], %[outptr1], #0x10\n"
+                                "prfm PSTL1KEEP, [%[outptr3], #0x20]\n"
+                                "add %[inptr], %[inptr], #0x40\n"
+                                "str q13, [%[outptr2]]\n"
+                                "add %[outptr2], %[outptr2], #0x10\n"
+                                "str q14, [%[outptr3]]\n"
+                                "add %[outptr3], %[outptr3], #0x10\n"
+                            : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
+                              [inptr] "+r" (inptr)
+                            : [biasptr] "r" (biasptr)
+                            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
+                            );
+                        }
+                    }
+                    break;
+
+
+                }
+            }
+        }
+    }
+}
+
+#endif // __aarch64__