blob: 941fed0ba88895b525dd5e3da9a0a8d129b19d01 [file] [log] [blame]
/*
* Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
#pragma once
#include "arm_gemm_local.hpp"
#include "gemm_common.hpp"
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
namespace arm_gemm
{
/** Identifies the family of GEMM/GEMV implementation a kernel belongs to.
 *
 * No explicit values are given, so the enumerators are numbered 0, 1, 2, ...
 * in declaration order; inserting or reordering entries changes those values.
 * DEFAULT is the value that default-constructed KernelDescription and
 * GemmConfig objects in this header start with.
 */
enum class GemmMethod
{
    DEFAULT,

    // GEMV-style methods.
    GEMV_BATCHED,
    GEMV_PRETRANSPOSED,
    GEMV_NATIVE_TRANSPOSED,

    // GEMM-style methods.
    GEMM_NATIVE,
    GEMM_HYBRID,
    GEMM_INTERLEAVED,
    GEMM_INTERLEAVED_2D,

    // Methods related to quantized operation.
    QUANTIZE_WRAPPER,
    QUANTIZE_WRAPPER_2D,
    GEMM_HYBRID_QUANTIZED
};
/** Memory layout tags for the weights operand.
 *
 * UNSPECIFIED and ANY are two small sentinel values (0x1 and 0x2).
 *
 * Every other enumerator is named OHWI[o<X>][i<Y>][_bf16]. Reading the
 * constants below, the value is packed as follows (X and Y default to 1 when
 * absent from the name):
 *   - bits 20 and above : Y, the number following 'i'
 *   - bits 8..19        : X, the number following 'o'
 *   - bit 4 (0x10)      : set only for the "_bf16" variants
 * e.g. OHWIo8i4_bf16 == (4 << 20) | (8 << 8) | 0x10 == 0x400810.
 *
 * NOTE(review): code elsewhere presumably decodes these bit fields, so the
 * literals must be kept exactly as they are - confirm before adding formats.
 */
enum class WeightFormat
{
    // Sentinels.
    UNSPECIFIED = 0x1,
    ANY         = 0x2,

    // i == 1 (no 'i' suffix).
    OHWI     = 0x100100,
    OHWIo2   = 0x100200,
    OHWIo4   = 0x100400,
    OHWIo8   = 0x100800,
    OHWIo16  = 0x101000,
    OHWIo32  = 0x102000,
    OHWIo64  = 0x104000,
    OHWIo128 = 0x108000,

    // i == 2, each with a _bf16 twin.
    OHWIo4i2       = 0x200400,
    OHWIo4i2_bf16  = 0x200410,
    OHWIo8i2       = 0x200800,
    OHWIo8i2_bf16  = 0x200810,
    OHWIo16i2      = 0x201000,
    OHWIo16i2_bf16 = 0x201010,
    OHWIo32i2      = 0x202000,
    OHWIo32i2_bf16 = 0x202010,
    OHWIo64i2      = 0x204000,
    OHWIo64i2_bf16 = 0x204010,

    // i == 4, each with a _bf16 twin.
    OHWIo4i4       = 0x400400,
    OHWIo4i4_bf16  = 0x400410,
    OHWIo8i4       = 0x400800,
    OHWIo8i4_bf16  = 0x400810,
    OHWIo16i4      = 0x401000,
    OHWIo16i4_bf16 = 0x401010,
    OHWIo32i4      = 0x402000,
    OHWIo32i4_bf16 = 0x402010,
    OHWIo64i4      = 0x404000,
    OHWIo64i4_bf16 = 0x404010,

    // i == 8.
    OHWIo2i8  = 0x800200,
    OHWIo4i8  = 0x800400,
    OHWIo8i8  = 0x800800,
    OHWIo16i8 = 0x801000,
    OHWIo32i8 = 0x802000,
    OHWIo64i8 = 0x804000
};
/** Summary record for a single GEMM kernel.
 *
 * It is the return type of get_gemm_method() and the element type of the
 * vector returned by get_compatible_kernels(), both declared in this header.
 */
struct KernelDescription
{
    GemmMethod  method         = GemmMethod::DEFAULT; // Implementation family of the kernel.
    std::string name           = "";                  // Kernel name.
    bool        is_default     = false;               // NOTE(review): presumably marks the kernel chosen by default - confirm.
    uint64_t    cycle_estimate = 0;                   // NOTE(review): presumably an estimated cost in cycles - confirm.

    /** Build a fully populated description.
     *
     * @param m Implementation family.
     * @param n Kernel name. Taken by value and moved into the member, so an
     *          rvalue argument is not copied a second time.
     * @param d Value for is_default (defaults to false).
     * @param c Value for cycle_estimate (defaults to 0).
     */
    KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
        : method(m), name(std::move(n)), is_default(d), cycle_estimate(c)
    {
    }

    /** Build a description holding the in-class defaults listed above. */
    KernelDescription() noexcept
    {
    }
};
struct GemmConfig
{
GemmMethod method = GemmMethod::DEFAULT;
std::string filter = "";
unsigned int inner_block_size = 0;
unsigned int outer_block_size = 0;
WeightFormat weight_format = WeightFormat::ANY;
GemmConfig(GemmMethod method) : method(method)
{
}
GemmConfig()
{
}
};
/** Activation selection plus two float parameters.
 *
 * NOTE(review): how param1/param2 are interpreted for each Type is not
 * visible in this header - presumably they are the bounds used by
 * BoundedReLU; confirm against the kernels.
 */
struct Activation
{
    /** Supported activation kinds (numbered 0..2 in this order). */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;   // Which activation is selected.
    float param1; // First parameter.
    float param2; // Second parameter.

    /** Construct from a kind and up to two parameters.
     *
     * With no arguments this yields Type::None with both parameters 0.0f.
     * The constructor is not explicit, so a bare Type converts implicitly.
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type{type}, param1{p1}, param2{p2}
    {
    }
};
/** Problem description handed to the GEMM entry points declared in this header.
 *
 * Plain value bundle: the constructor copies each argument into the member of
 * the matching name. The two pointers (_ci, _cfg) are stored as given; this
 * struct has no destructor and never frees them.
 */
struct GemmArgs
{
    const CPUInfo    *_ci;             // CPU information.
    unsigned int      _Msize;          // Number of tiles.
    unsigned int      _Nsize;          // Output channels.
    unsigned int      _Ksize;          // Input channels.
    unsigned int      _Ksections;      // NOTE(review): presumably the number of K-sized sections - confirm.
    unsigned int      _nbatches;       // Number of batches.
    unsigned int      _nmulti;         // Number of GEMMs to be performed.
    bool              _indirect_input; // Indirect-input flag.
    Activation        _act;            // Activation setting.
    int               _maxthreads;     // Thread-count limit.
    bool              _fixed_format;   // Defaults to false.
    bool              _fast_mode;      // Defaults to false.
    bool              _accumulate;     // Defaults to false.
    const GemmConfig *_cfg;            // Optional configuration; defaults to nullptr.

    /** Store every argument in its member.
     *
     * Everything up to and including @p maxthreads is mandatory; the three
     * trailing flags default to false and @p cfg defaults to nullptr.
     */
    GemmArgs(const CPUInfo    *ci,
             unsigned int      M,
             unsigned int      N,
             unsigned int      K,
             unsigned int      Ksections,
             unsigned int      nbatches,
             unsigned int      nmulti,
             bool              indirect_input,
             Activation        act,
             const int         maxthreads,
             bool              fixed_format = false,
             bool              fast_mode    = false,
             bool              accumulate   = false,
             const GemmConfig *cfg          = nullptr)
        : _ci{ci},
          _Msize{M},
          _Nsize{N},
          _Ksize{K},
          _Ksections{Ksections},
          _nbatches{nbatches},
          _nmulti{nmulti},
          _indirect_input{indirect_input},
          _act{act},
          _maxthreads{maxthreads},
          _fixed_format{fixed_format},
          _fast_mode{fast_mode},
          _accumulate{accumulate},
          _cfg{cfg}
    {
    }
};
/** Parameters for a 32-bit requantization output stage.
 *
 * Two modes are supported and told apart by per_channel_requant:
 *   - false: the per_layer_* scalars are used; the per_channel_* pointers
 *            stay nullptr.
 *   - true : the per_channel_* arrays are used; the per_layer_* scalars stay 0.
 * The pointers are stored as given; this struct never frees them.
 */
struct Requantize32
{
    const int32_t *bias                     = nullptr; // Bias data (may be nullptr).
    size_t         bias_multi_stride        = 0;       // NOTE(review): presumably the bias stride between "multis" - confirm.
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;   // Selects which group of fields below is meaningful.
    int32_t        per_layer_left_shift     = 0;       // Always >= 0 when set by the per-tensor constructor.
    int32_t        per_layer_right_shift    = 0;       // Always <= 0 when set by the per-tensor constructor.
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;       // Lower output bound.
    int32_t        maxval                   = 0;       // Upper output bound.

    /** All members keep their in-class defaults. */
    Requantize32() = default;

    /** Per-tensor quantization.
     *
     * The single signed @p requant_shift is split by sign: a positive value is
     * stored in per_layer_left_shift (right shift becomes 0), a negative value
     * is stored unchanged in per_layer_right_shift (left shift becomes 0).
     */
    Requantize32(const int32_t *bias,
                 size_t         bias_multi_stride,
                 int32_t        a_offset,
                 int32_t        b_offset,
                 int32_t        c_offset,
                 int32_t        requant_shift,
                 int32_t        requant_mul,
                 int32_t        minv,
                 int32_t        maxv)
        : bias{bias},
          bias_multi_stride{bias_multi_stride},
          a_offset{a_offset},
          b_offset{b_offset},
          c_offset{c_offset},
          per_channel_requant{false},
          per_layer_left_shift{std::max<int32_t>(requant_shift, 0)},
          per_layer_right_shift{std::min<int32_t>(requant_shift, 0)},
          per_layer_mul{requant_mul},
          minval{minv},
          maxval{maxv}
    {
    }

    /** Per-channel quantization.
     *
     * The three array pointers are stored as passed; the per_layer_* scalars
     * are left at 0.
     */
    Requantize32(const int32_t *bias,
                 size_t         bias_multi_stride,
                 int32_t        a_offset,
                 int32_t        b_offset,
                 int32_t        c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t        minv,
                 int32_t        maxv)
        : bias{bias},
          bias_multi_stride{bias_multi_stride},
          a_offset{a_offset},
          b_offset{b_offset},
          c_offset{c_offset},
          per_channel_requant{true},
          per_channel_left_shifts{requant_left_shifts},
          per_channel_right_shifts{requant_right_shifts},
          per_channel_muls{requant_muls},
          minval{minv},
          maxval{maxv}
    {
    }
};
/** Output-stage parameter block carrying a single float scale. */
struct DequantizeFloat
{
    float scale = 0; // Scale factor; 0 when default-constructed.

    /** Leave scale at its in-class default of 0. */
    DequantizeFloat() = default;

    /** Store the given scale.
     *
     * The constructor is not explicit, so a float converts implicitly.
     */
    DequantizeFloat(const float scale) : scale{scale}
    {
    }
};
/** Empty tag type.
 *
 * It is the default OutputStage template argument of the GEMM entry points
 * declared in this header.
 */
struct Nothing
{
    // Intentionally has no members.
};
/* UniqueGemmCommon<Top, Tret>: shorthand for a std::unique_ptr that owns a
 * GemmCommon<Top, Tret> object. It is the return type of gemm() below. */
template <typename Top, typename Tret>
using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
/* get_gemm_method(): Given the templated types and provided parameters,
 * which is the preferred method to implement this GEMM?
 *
 * args        - problem description (see GemmArgs).
 * OutputStage - optional output-stage parameters; the argument defaults to a
 *               default-constructed OutputStage, whose type defaults to Nothing.
 * Returns a KernelDescription for the preferred kernel. */
template <typename Top, typename Tret, class OutputStage = Nothing>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
/* gemm(): Return a uniquely-owned GemmCommon<Top, Tret> object for the
 * problem described by 'args' and the optional output stage (defaulting to a
 * default-constructed OutputStage, whose type defaults to Nothing).
 * NOTE(review): whether the returned pointer can be null when no kernel
 * matches is not visible here - confirm against the implementation. */
template <typename Top, typename Tret, class OutputStage = Nothing>
UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
/* get_compatible_kernels(): Return a vector of KernelDescription entries for
 * the problem described by 'args' and the optional output stage (defaulting
 * to a default-constructed OutputStage, whose type defaults to Nothing).
 * NOTE(review): presumably one entry per kernel able to handle the problem;
 * ordering is not visible here - confirm against the implementation. */
template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
/* has_opt_gemm(): Boolean query for the problem described by 'args' and the
 * optional output stage (defaulting to a default-constructed OutputStage,
 * whose type defaults to Nothing).
 * 'weight_format' is taken by non-const reference, so the implementation may
 * write to it.
 * NOTE(review): presumably returns true when an optimised kernel exists and
 * stores the weight format that kernel expects in 'weight_format' - confirm
 * against the implementation. */
template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP