Integrate new pretranspose_b_array with extra fused transpose of B

This patch fuses the transposition performed in ACL with the transformations done in arm_gemm (pretranspose_B_array) when the underlying kernel and transform support it. Because these transformations apply to constant Rhs matrices, fusing them should improve start-up time and memory footprint. The transformations in arm_gemm are kernel-specific: the Rhs matrix is transformed into kernel-dependent layouts to improve performance.
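
For reference, the caller-side flow this enables looks roughly like the sketch below. This is illustrative only: the helper name, buffers and strides are placeholders, and the real ACL integration also handles batching and quantized cases. It only uses the interface shown in this patch (B_pretranspose_supports_transpose, pretranspose_B_array with the new 'transposed' flag).

    // Sketch, not the actual ACL dispatch code. 'pretransposed' is assumed to be
    // sized via gemm->get_B_pretransposed_array_size().
    void prepare_constant_rhs(arm_gemm::GemmCommon<float, float> *gemm, void *pretransposed,
                              const float *b_needing_transpose, int ldb,
                              const float *b_transposed_by_acl, int ldb_acl,
                              int b_multi_stride) {
        if (gemm->B_pretranspose_supports_transpose()) {
            // Fused path: PrepareB reads the transposed B directly, so the separate
            // ACL transpose kernel and its intermediate tensor are no longer needed.
            gemm->pretranspose_B_array(pretransposed, b_needing_transpose, ldb, b_multi_stride, /*transposed=*/true);
        } else {
            // Fallback: keep the old behaviour and transpose in ACL beforehand.
            gemm->pretranspose_B_array(pretransposed, b_transposed_by_acl, ldb_acl, b_multi_stride, /*transposed=*/false);
        }
    }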

Resolves: COMPMID-6595

Change-Id: Id2932dd966e59f903c279417bebcea83d9a42464
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11144
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
index 436316c..a6c9677 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -221,7 +221,9 @@
         return roundup(_Nsize, strategy::out_width()) * roundup(_Ksize, strategy::k_unroll()) * _nmulti * sizeof(Toi);
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
+
         Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
         _B_transposed = buffer;
         strategy strat(_ci);
@@ -237,7 +239,7 @@
                     const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size;
 
                     strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb,
-                                               x0, xmax, k0, kmax);
+                                               x0, xmax, k0, kmax, false);
 
                     buffer += size;
                 }
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 1780375..89c2d5a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -631,11 +631,16 @@
         }
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size());
+    bool B_pretranspose_supports_transpose() const override {
+        strategy strat(_args._ci);
+        return strat.transforms.PrepareB_supports_transpose();
     }
 
-    void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override {
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
+    }
+
+    void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed, size_t start, size_t end) override {
         if (end >= get_B_pretranspose_window_size()) {
             requantize_bias(in_buffer, B, ldb, B_multi_stride);
         }
@@ -717,7 +722,8 @@
                             strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb,
                                                       x0, xmax,
                                                       (k_section_base * _args._Ksize) + k_offset,               // K starting point - compute row to read based on our section and the true section length.
-                                                      (k_section_base * _args._Ksize) + k_offset + k_length);   // K end point - starting point plus length computed above.
+                                                      (k_section_base * _args._Ksize) + k_offset + k_length,    // K end point - starting point plus length computed above.
+                                                      transposed);
 
                             // We need to modify our position based on the ROUNDED version of what we just did.
                             unsigned int padded_length = roundup(k_length, strategy::k_unroll());
@@ -731,7 +737,7 @@
                 } else {
                     // In the single K section case, can process the whole lot in one go.
                     strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb,
-                                              n_start, n_end, k0, std::min(kmax, _args._Ksize));
+                                              n_start, n_end, k0, std::min(kmax, _args._Ksize), transposed);
                 }
             }
         }
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index efb5bd1..f12efe4 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -277,7 +277,9 @@
         }
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
+
         requantize_bias(in_buffer, B, ldb, B_multi_stride);
 
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(in_buffer);
@@ -296,7 +298,7 @@
                     const unsigned int size = roundup(xmax-x0, strategy::out_width()) * k_size;
 
                     strat.transforms.PrepareB( buffer, B + (multi * B_multi_stride), ldb,
-                                               x0, xmax, k0, kmax);
+                                               x0, xmax, k0, kmax, false);
 
                     buffer += size;
                 }
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 362a3e3..4f732f7 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -1067,11 +1067,18 @@
         }
     }
 
-    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, 0, get_B_pretranspose_window_size());
+    // Support for transposed B is a property of the strategy::transpose type
+    bool B_pretranspose_supports_transpose() const override {
+        typename transform_type<strategy, MergeStep && std::is_same<OutputStage, Requantize32>::value>::type transforms;
+
+        return transforms.PrepareB_supports_transpose();
     }
 
-    void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, size_t start, size_t end) override {
+    void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed) override {
+        pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
+    }
+
+    void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, const bool transposed, size_t start, size_t end) override {
         // Perform column sums etc as part of the last block.
         if (end >= get_B_pretranspose_window_size()) {
             requantize_bias(in_buffer, B, ldb, B_multi_stride);
@@ -1134,7 +1141,8 @@
                         strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
                                                   x0, xmax,
                                                   (k_section_base * _Ksize) + k_offset,               // K starting point - compute row to read based on our section and the true section length.
-                                                  (k_section_base * _Ksize) + k_offset + k_length);   // K end point - starting point plus length computed above.
+                                                  (k_section_base * _Ksize) + k_offset + k_length,    // K end point - starting point plus length computed above.
+                                                  transposed);
 
                         // We need to modify our position based on the ROUNDED version of what we just did.
                         unsigned int padded_length = roundup(k_length, strategy::k_unroll());
@@ -1149,7 +1157,7 @@
                 // In the single K section case, can process the whole lot in one go.
                 // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize.
                 strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
-                                          current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize));
+                                          current.x0(), current.xmax(), current.k0(), std::min(current.kmax(), _Ksize), transposed);
                 buffer += roundup(current.xmax() - current.x0(), strategy::out_width()) * roundup(current.kmax() - current.k0(), strategy::k_unroll());
             }
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 4fc9b34..ad504f2 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -88,8 +88,8 @@
         return _subgemm->get_B_pretransposed_array_size();
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride);
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride, transposed);
     }
 
     void set_pretransposed_B_data(void *buffer) override {
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 86b33d0..f70fc98 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -215,7 +215,9 @@
         }
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
+
         requantize_bias(buffer, B, ldb, B_multi_stride);
 
         // The actual transposed buffer goes after the column sums (if any)
@@ -225,7 +227,7 @@
         strategy strat(_args._ci);
 
         for (unsigned int multi=0; multi<_args._nmulti; multi++) {
-            strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize);
+            strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize, false);
         }
 
         _B_pretransposed = B_buffer;
diff --git a/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
new file mode 100644
index 0000000..148678b
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/interleave-8way.cpp
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include <arm_neon.h>
+
+#include <alloca.h>
+#include <cstring>
+
+#include "transform.hpp"
+#include "utils.hpp"
+
+namespace arm_gemm {
+
+namespace {
+
+// Helper function to interleave a single 4x4 block of 32-bit values
+// together.
+
+// _full version doesn't need to worry about any padding.
+static inline void transpose_block_32_full(const uint8_t * __restrict in_ptr0, const uint8_t * __restrict in_ptr1, const uint8_t * __restrict in_ptr2, const uint8_t * __restrict in_ptr3, uint8_t * __restrict out_ptr, long output_stride) {
+    uint32x4_t inputs[4];
+    uint32x4_t inters[4];
+    uint32x4_t outputs[4];
+
+    inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr0));
+    inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr1));
+    inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr2));
+    inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(in_ptr3));
+
+    inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+    inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+    inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+    inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+    outputs[0] = vzip1q_u32(inters[0], inters[2]);
+    outputs[1] = vzip2q_u32(inters[0], inters[2]);
+    outputs[2] = vzip1q_u32(inters[1], inters[3]);
+    outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
+    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+    vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+}
+
+// _part version: Only read "bytes_in" bytes, not a full vector.  Only write
+// out 4-byte blocks that have some live content (if bytes_in is not a
+// multiple of 4 there will be some padding in each 4-block)
+static inline void transpose_block_32_part(const uint8_t *in_ptr0, const uint8_t *in_ptr1, const uint8_t *in_ptr2, const uint8_t *in_ptr3, uint8_t *out_ptr, long bytes_in, long output_stride) {
+    uint32x4_t inputs[4];
+    uint32x4_t inters[4];
+    uint32x4_t outputs[4];
+    uint8_t scratch[16] = {0};
+
+    long num_outs = iceildiv<long>(bytes_in, 4);
+
+    memcpy(scratch, in_ptr0, bytes_in);
+    inputs[0] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+    memcpy(scratch, in_ptr1, bytes_in);
+    inputs[1] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+    memcpy(scratch, in_ptr2, bytes_in);
+    inputs[2] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+    memcpy(scratch, in_ptr3, bytes_in);
+    inputs[3] = vld1q_u32(reinterpret_cast<const uint32_t *>(scratch));
+
+    inters[0] = vzip1q_u32(inputs[0], inputs[2]);
+    inters[1] = vzip2q_u32(inputs[0], inputs[2]);
+    inters[2] = vzip1q_u32(inputs[1], inputs[3]);
+    inters[3] = vzip2q_u32(inputs[1], inputs[3]);
+
+    outputs[0] = vzip1q_u32(inters[0], inters[2]);
+    outputs[1] = vzip2q_u32(inters[0], inters[2]);
+    outputs[2] = vzip1q_u32(inters[1], inters[3]);
+    outputs[3] = vzip2q_u32(inters[1], inters[3]);
+
+    do {
+        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr), outputs[0]);
+        if (num_outs < 2)
+            break;
+        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride), outputs[1]);
+        if (num_outs < 3)
+            break;
+        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*2), outputs[2]);
+        if (num_outs < 4)
+            break;
+        vst1q_u32(reinterpret_cast<uint32_t *>(out_ptr + output_stride*3), outputs[3]);
+    } while (0);
+}
+
+template<unsigned N>
+struct Unroll {
+    template<typename F>
+    static void run(F f) {
+        Unroll<N-1>::run(f);
+        f(N-1);
+    }
+};
+
+template<>
+struct Unroll<0> {
+    template<typename F>
+    static void run(F) {
+    }
+};
+
+// Interleave some multiple of 4 rows together.
+//
+// The template parameter BLOCKS controls the size of the inner loop - each BLOCK is 4 rows.
+// The function parameter interleave_multiple controls the number of times the inner loop is run.
+//
+// The total interleave depth for a given run is therefore BLOCKS * interleave_multiple * 4.
+template<unsigned BLOCKS>
+void a64_interleave_1x4(uint8_t *out, const uint8_t *in, long width, long in_stride, long height, long interleave_multiple) {
+    const long total_interleave_depth = BLOCKS * 4 * interleave_multiple;
+    constexpr long loop_interleave_depth = BLOCKS * 4;
+
+    uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width));
+
+    if (height % total_interleave_depth) {
+        memset(pad_row, 0, width);
+    }
+
+    // Outer loop: process blocks of total_interleave_depth rows at a time.
+    for (long y0_base=0; y0_base<height; y0_base+=total_interleave_depth) {
+        // Middle loop: process each "interleave_multiple" block of rows.
+        for (long block=0; block<interleave_multiple; block++) {
+            const long y0 = y0_base + (block * loop_interleave_depth);
+            uint8_t *out_ptr = out + (block * loop_interleave_depth * 4); // 4 is the blocking depth (we interleave 4 bytes at a time from each input)
+
+            // Create and set up input row pointers.  The idea is that these
+            // should entirely fit in the register file, so we don't have to
+            // repeatedly load them (or perform the padding check)
+            const uint8_t *in_ptrs[loop_interleave_depth];
+            Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+                in_ptrs[y] = (y+y0 < height) ? in + ((y+y0) * in_stride) : pad_row;
+            });
+
+            long bytes_left = width;
+            // Process full vectors using transpose_block_32_full()
+            while (bytes_left >= 16) { // 16 is the vector length in bytes
+                Unroll<BLOCKS>::run( [&](unsigned u) {
+                    transpose_block_32_full(in_ptrs[u*4 + 0],  in_ptrs[u*4 + 1],  in_ptrs[u*4 + 2],  in_ptrs[u*4 + 3],
+                                            out_ptr + 16*u, total_interleave_depth * 4); // 4 is the blocking depth
+                });
+
+                Unroll<loop_interleave_depth>::run( [&](unsigned y) {
+                    in_ptrs[y] += 16; // 16 is the vector length in bytes
+                });
+
+                out_ptr += total_interleave_depth * 16; // 16 is the vector length in bytes
+                bytes_left -= 16; // 16 is the vector length in bytes
+            }
+
+            // Process any remaining bytes using transpose_block_32_part()
+            if (bytes_left) {
+                Unroll<BLOCKS>::run( [&](unsigned u) {
+                    transpose_block_32_part(in_ptrs[u*4 + 0],  in_ptrs[u*4 + 1],  in_ptrs[u*4 + 2],  in_ptrs[u*4 + 3], 
+                                            out_ptr + 16*u, bytes_left, total_interleave_depth * 4);
+                });
+            }
+        }
+
+        // Update "out" pointer for next set of total_interleave_depth rows
+        out += total_interleave_depth * roundup<long>(width, 4);
+    }
+}
+
+} // anonymous namespace
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+    uint8_t *out, const uint8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+    a64_interleave_1x4<4>(
+        reinterpret_cast<uint8_t *>(out),
+        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+        (xmax - x0),
+        stride,
+        (ymax - y0),
+        1
+    );
+}
+
+template<>
+void Transform<16, 4, false, VLType::None>(
+    int8_t *out, const int8_t *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+    a64_interleave_1x4<4>(
+        reinterpret_cast<uint8_t *>(out),
+        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+        (xmax - x0),
+        stride,
+        (ymax - y0),
+        1
+    );
+}
+
+template<>
+void Transform<12, 1, false, VLType::None>(
+    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+    a64_interleave_1x4<3>(
+        reinterpret_cast<uint8_t *>(out),
+        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+        (xmax - x0) * sizeof(float),
+        stride * sizeof(float),
+        (ymax - y0),
+        1
+    );
+}
+
+template<>
+void Transform<16, 1, false, VLType::None>(
+    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+    a64_interleave_1x4<4>(
+        reinterpret_cast<uint8_t *>(out),
+        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+        (xmax - x0) * sizeof(float),
+        stride * sizeof(float),
+        (ymax - y0),
+        1
+    );
+}
+
+template<>
+void Transform<24, 1, false, VLType::None>(
+    float *out, const float *in, int stride, int y0, int ymax, int x0, int xmax)
+{
+    a64_interleave_1x4<3>(
+        reinterpret_cast<uint8_t *>(out),
+        reinterpret_cast<const uint8_t *>(in + y0 * stride + x0),
+        (xmax - x0) * sizeof(float),
+        stride * sizeof(float),
+        (ymax - y0),
+        2
+    );
+}
+
+} // namespace arm_gemm
+
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
index 171929e..bce4de7 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,7 @@
 #pragma once
 #ifdef __aarch64__
 
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
 #include "../performance_parameters.hpp"
 
 #define ARGLIST  \
@@ -71,7 +71,7 @@
         return true;
     }
 
-    StdTransformsFixed<rhs_operand_type, result_type, 4, 24, 1> transforms = {};
+    StdTransformsFixedTRB<rhs_operand_type, result_type, 4, 24, 1> transforms = {};
     template<typename T>
     static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
     {
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
index 759729d..7f85d2d 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, 2023 Arm Limited.
+ * Copyright (c) 2019-2021, 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,7 +24,7 @@
 #pragma once
 #ifdef __aarch64__
 
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
 #include "../performance_parameters.hpp"
 
 #define ARGLIST  \
@@ -71,7 +71,7 @@
         return true;
     }
 
-    StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 1> transforms = {};
+    StdTransformsFixedTRB<rhs_operand_type, result_type, 6, 16, 1> transforms = {};
     template<typename T>
     static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
     {
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
index 65ef407..19acfe8 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_8x12.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,7 +25,7 @@
 
 #ifdef __aarch64__
 
-#include "../std_transforms_fixed.hpp"
+#include "../std_transforms_fixed_trB.hpp"
 #include "../performance_parameters.hpp"
 
 #include "../bfloat.hpp"
@@ -68,7 +68,7 @@
     }
 
     // Use the standard fixed size transforms.
-    StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
+    StdTransformsFixedTRB<operand_type, result_type, 8, 12> transforms = {};
 
     template<typename T>
     static PerformanceParameters get_performance_parameters(const CPUInfo *ci) {
diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
index ce72703..d35825c 100644
--- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -184,9 +184,11 @@
         col_sums_pretransposed(B, ldb, B_multi_stride);
     }
 
-    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
+    void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+        assert(!transposed);
+
         uintptr_t buffer_int = reinterpret_cast<uintptr_t>(buffer);
-        _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride);
+        _subgemm->pretranspose_B_array(reinterpret_cast<void *>(buffer_int + col_sum_size()), B, ldb, B_multi_stride, transposed);
 
         requantize_bias(buffer, B, ldb, B_multi_stride);
     }
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
index 4669be9..a9cbf4e 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -63,9 +63,14 @@
         ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
     }
 
+    bool PrepareB_supports_transpose() const {
+        return false;
+    }
+
     template<typename TIn>
     void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
-                  const int xmax, const int k0, const int kmax) const {
+                  const int xmax, const int k0, const int kmax, bool transposed) const {
+        assert(!transposed);
         Transform<width, block,  true>(out, in, stride, x0, xmax, k0, kmax);
     }
 
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp
new file mode 100644
index 0000000..1db7164
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed_trB.hpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "convolver.hpp"
+#include "mergeresults.hpp"
+#include "transform.hpp"
+#include "interleave_indirect.hpp"
+
+namespace arm_gemm {
+
+/*
+ * Define "standard" transforms for the blocked GEMMs with fixed vector
+ * length.  This version supports accepting the RHS/B matrix in transposed
+ * format.
+ *
+ * This assumes that A is interleaved 'height' ways, B is interleaved
+ * 'width' ways and transposed, and that the merge needs to work in 'height'
+ * x 'width' blocks.
+ *
+ * The optional 'block' parameter is for kernels using dot-product type
+ * instructions like UDOT and SDOT.
+ */
+template<typename TOperand, typename TResult, unsigned int height, unsigned int width, unsigned int block=1, bool integrate_sums=false>
+class StdTransformsFixedTRB
+{
+public:
+    template<typename TIn>
+    void PrepareA(TOperand *out, const TIn *in, const int stride, const int y0,
+                  const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) const {
+        Interleave<height, block, VLType::None>(out, in, stride, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+    }
+
+    template<typename TIn>
+    void PrepareA_indirect(TOperand *out, const TIn * const * const *ptr, size_t stringlen, size_t rounded_stringlen, const int y0,
+                           const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) {
+        IndirectInterleave<height, block, VLType::None>(out, ptr, stringlen, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+    }
+
+    template<typename TIn>
+    void PrepareA_convolution(TOperand *out, const TIn *ptr, size_t stride, const convolver<TIn> &conv, size_t rounded_stringlen,
+                              const int y0, const int ymax, const int k0, const int kmax, int32_t row_sum_multiplier) {
+        ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
+    }
+
+    bool PrepareB_supports_transpose() const {
+        return true;
+    }
+
+    template<typename TIn>
+    void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
+                  const int xmax, const int k0, const int kmax, bool transposed) const {
+        if (transposed) {
+            Transform<width, block, false>(out, in, stride, x0, xmax, k0, kmax);
+        } else {
+            Transform<width, block,  true>(out, in, stride, x0, xmax, k0, kmax);
+        }
+    }
+
+    template<typename TOut>
+    void Merge(TOut *out, const TResult *in, int stride, int y0, int ymax, int x0, int xmax, const TOut *bias, const Activation act, bool append) const {
+        MergeResults<width, height>(out, in, stride, y0, ymax, x0, xmax, bias, act, append);
+    }
+};
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
index afe24e7..40f6162 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sme.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -60,9 +60,14 @@
         ConvolutionInterleave<height_vectors, block, VLType::SME>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
     }
 
+    bool PrepareB_supports_transpose() const {
+        return false;
+    }
+
     template<typename TIn>
     void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
-                  const int xmax, const int k0, const int kmax) {
+                  const int xmax, const int k0, const int kmax, bool transposed) {
+        assert (!transposed);
         Transform<width_vectors, block,  true, VLType::SME>(out, in, stride, x0, xmax, k0, kmax);
     }
 
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
index 3256d91..c516bfc 100644
--- a/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_sve.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 Arm Limited.
+ * Copyright (c) 2017-2018,2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -61,9 +61,14 @@
         ConvolutionInterleave<height, block, VLType::None>(out, ptr, stride, conv, rounded_stringlen, y0, ymax, k0, kmax, integrate_sums, row_sum_multiplier);
     }
 
+    bool PrepareB_supports_transpose() const {
+        return false;
+    }
+
     template<typename TIn>
     void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
-                  const int xmax, const int k0, const int kmax) {
+                  const int xmax, const int k0, const int kmax, bool transposed) {
+        assert (!transposed);
         Transform<width_vectors, block,  true, VLType::SVE>(out, in, stride, x0, xmax, k0, kmax);
     }
 
diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp
index 5aa62f0..45e4f0e 100644
--- a/src/core/NEON/kernels/arm_gemm/transform.cpp
+++ b/src/core/NEON/kernels/arm_gemm/transform.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -134,7 +134,14 @@
 #endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
 #ifdef ARM_COMPUTE_ENABLE_BF16
 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
-#endif
+#endif // ARM_COMPUTE_ENABLE_BF16
 #endif // AArch32
 
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int);
+#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#ifdef ARM_COMPUTE_ENABLE_BF16
+template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
+#endif // ARM_COMPUTE_ENABLE_BF16
+
 } // namespace arm_gemm