Fix for inclusion of "arm_gemm" from src into "Types.h" from core
- Added arm_compute::WeightFormat and converted to/from arm_gemm::WeightFormat
when needed through two map functions.
- Moved to_string(WeightFormat) to TypePrinter.h
Resolves: COMPMID-5415
Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com>
Change-Id: I65f7942100bcd4dbf2c5cf6c07f26c8e1e3bf86e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/438511
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Sicong Li <sicong.li@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7985
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 989cdfb..c87c97c 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -32,7 +32,6 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/experimental/IPostOp.h"
#include "arm_compute/core/utils/misc/Macros.h"
-#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "support/Bfloat16.h"
#include "support/Half.h"
@@ -775,10 +774,10 @@
private:
std::pair<unsigned int, unsigned int> _stride;
- unsigned int _pad_left;
- unsigned int _pad_top;
- unsigned int _pad_right;
- unsigned int _pad_bottom;
+ unsigned int _pad_left;
+ unsigned int _pad_top;
+ unsigned int _pad_right;
+ unsigned int _pad_bottom;
DimensionRoundingType _round_type;
};
@@ -920,14 +919,14 @@
}
private:
- std::vector<float> _min_sizes;
- std::vector<float> _variances;
- float _offset;
- bool _flip;
- bool _clip;
- std::vector<float> _max_sizes;
- std::vector<float> _aspect_ratios;
- Coordinates2D _img_size;
+ std::vector<float> _min_sizes;
+ std::vector<float> _variances;
+ float _offset;
+ bool _flip;
+ bool _clip;
+ std::vector<float> _max_sizes;
+ std::vector<float> _aspect_ratios;
+ Coordinates2D _img_size;
std::array<float, 2> _steps;
};
@@ -1172,15 +1171,15 @@
}
private:
- unsigned int _max_detections;
- unsigned int _max_classes_per_detection;
- float _nms_score_threshold;
- float _iou_threshold;
- unsigned int _num_classes;
+ unsigned int _max_detections;
+ unsigned int _max_classes_per_detection;
+ float _nms_score_threshold;
+ float _iou_threshold;
+ unsigned int _num_classes;
std::array<float, 4> _scales_values;
- bool _use_regular_nms;
- unsigned int _detection_per_class;
- bool _dequantize_scores;
+ bool _use_regular_nms;
+ unsigned int _detection_per_class;
+ bool _dequantize_scores;
};
/** Pooling Layer Information struct*/
@@ -1613,13 +1612,13 @@
}
private:
- float _img_width;
- float _img_height;
- float _scale;
- bool _apply_scale;
- bool _correct_transform_coords;
+ float _img_width;
+ float _img_height;
+ float _scale;
+ bool _apply_scale;
+ bool _correct_transform_coords;
std::array<float, 4> _weights;
- float _bbox_xform_clip;
+ float _bbox_xform_clip;
};
/** Activation Layer Information class */
@@ -1895,13 +1894,117 @@
int32_t _shrink_axis_mask;
};
+/** Memory layouts for the weights tensor.
+ *
+ * * UNSPECIFIED is used to select kernels that do not run in
+ * variable weights mode.
+ *
+ * * ANY is used to query the kernel database to retrieve any of the
+ * kernels that runs in variable weights mode. Once a kernel is
+ * found, the specific format expected by the kernel can be
+ * retrieved by the user for reordering the weights tensor
+ * accordingly.
+ *
+ * The other values OHWIo{interleave_by}i{block_by} describe the
+ * memory layout of a 4D tensor with layout OHWI that has been
+ * transformed into a 4D tensor with dimensions O'HWI' where:
+ *
+ * O' = first multiple of {interleave_by} s.t. O<=O'
+ * I' = first multiple of {block_by} s.t. I<=I'
+ *
+ * The total size of the dst tensor is O' x H x W x I'
+ *
+ * The access function of the tensor with layout
+ * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+ * access function, where the 6 parameters are computed as follows:
+ *
+ * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+ *
+ * x4 = h RANGE [0, H-1] SIZE: H
+ * x3 = w RANGE [0, W-1] SIZE: W
+ * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
+ * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
+ * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
+ * TOTAL SIZE: O' * H * W * I'
+ *
+ * 4D 6D
+ * ----------------- -----------------------------------
+ * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
+ * + x4 * W * I' * {interleave_by}
+ * + x3 * I' * {interleave_by}
+ * + x2 * {interleave_by} * {block_by}
+ * + x1 * {block_by}
+ * + x0
+ *
+ * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
+ * for the OHWIo{interleave_by}i{block_by} format is in reality seen
+ * as a 2D tensor, where the number of rows is O'/{interleave_by}
+ * and the number of columns is {interleave_by} * H * W * I'.
+ *
+ * The postfix *_bf16 is for the memory layout needed for the
+ * fast-mode kernels, in which the weights are passed in bfloat16
+ * format.
+ */
+enum class WeightFormat
+{
+ UNSPECIFIED = 0x1,
+ ANY = 0x2,
+ OHWI = 0x100100,
+ OHWIo2 = 0x100200,
+ OHWIo4 = 0x100400,
+ OHWIo8 = 0x100800,
+ OHWIo16 = 0x101000,
+ OHWIo32 = 0x102000,
+ OHWIo64 = 0x104000,
+ OHWIo128 = 0x108000,
+ OHWIo4i2 = 0x200400,
+ OHWIo4i2_bf16 = 0x200410,
+ OHWIo8i2 = 0x200800,
+ OHWIo8i2_bf16 = 0x200810,
+ OHWIo16i2 = 0x201000,
+ OHWIo16i2_bf16 = 0x201010,
+ OHWIo32i2 = 0x202000,
+ OHWIo32i2_bf16 = 0x202010,
+ OHWIo64i2 = 0x204000,
+ OHWIo64i2_bf16 = 0x204010,
+ OHWIo4i4 = 0x400400,
+ OHWIo4i4_bf16 = 0x400410,
+ OHWIo8i4 = 0x400800,
+ OHWIo8i4_bf16 = 0x400810,
+ OHWIo16i4 = 0x401000,
+ OHWIo16i4_bf16 = 0x401010,
+ OHWIo32i4 = 0x402000,
+ OHWIo32i4_bf16 = 0x402010,
+ OHWIo64i4 = 0x404000,
+ OHWIo64i4_bf16 = 0x404010,
+ OHWIo2i8 = 0x800200,
+ OHWIo4i8 = 0x800400,
+ OHWIo8i8 = 0x800800,
+ OHWIo16i8 = 0x801000,
+ OHWIo32i8 = 0x802000,
+ OHWIo64i8 = 0x804000
+};
+// Decode helpers for the OHWIo<interleave_by>i<block_by> encodings above.
+inline int interleave_by(const WeightFormat wf)
+{
+    return ((int)wf >> 8) & 0xFFF; // interleave factor is encoded in bits [8,19] of the enum value (e.g. OHWIo4 = 0x100400 -> 4)
+}
+inline int block_by(const WeightFormat wf)
+{
+    return ((int)wf >> 20) & 0xF; // block factor is encoded in bits [20,23] of the enum value (e.g. OHWIo4i4 = 0x400400 -> 4)
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+    return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY; // only the OHWI* layouts are concrete fixed formats
+}
+
/** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */
class WeightsInfo
{
public:
/** Default constructor */
WeightsInfo()
- : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
+ : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -1911,10 +2014,10 @@
* @param[in] kernel_height Kernel height.
* @param[in] num_kernels Number of convolution kernels.
* @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false.
- * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
+ * @param[in] weight_format (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
*/
WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false,
- arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED)
+ arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED)
: _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights), _weight_format(weight_format)
{
}
@@ -1946,7 +2049,7 @@
{
return _retain_internal_weights;
}
- arm_gemm::WeightFormat weight_format() const
+ arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
@@ -1960,12 +2063,12 @@
}
private:
- bool _are_reshaped;
- unsigned int _kernel_width;
- unsigned int _kernel_height;
- unsigned int _num_kernels;
- bool _retain_internal_weights;
- arm_gemm::WeightFormat _weight_format;
+ bool _are_reshaped;
+ unsigned int _kernel_width;
+ unsigned int _kernel_height;
+ unsigned int _num_kernels;
+ bool _retain_internal_weights;
+ arm_compute::WeightFormat _weight_format;
};
/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
@@ -2177,7 +2280,7 @@
_activation_info(),
_post_ops(),
_fixed_format(false),
- _weight_format(arm_gemm::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
{
}
/** Constructor
@@ -2196,13 +2299,13 @@
* @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
* @param[in] activation_info (Optional) Activation to apply after the matrix multiplication
* @param[in] post_ops (Optional) A sequence of post operations that are performed after the main operation.
- * @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_gemm::WeightFormat.
- * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_gemm::WeightFormat::UNSPECIFIED.
+ * @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
+ * @param[in] weight_format (Optional) arm_compute::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *>(),
- bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED) noexcept
+ bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -2392,7 +2495,7 @@
return _fixed_format;
}
- arm_gemm::WeightFormat weight_format() const
+ arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
@@ -2413,7 +2516,7 @@
ActivationLayerInfo _activation_info;
experimental::PostOpList<ITensorInfo *> _post_ops;
bool _fixed_format;
- arm_gemm::WeightFormat _weight_format;
+ arm_compute::WeightFormat _weight_format;
};
/** Winograd information */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index 2af11ad..a282662 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -131,7 +131,7 @@
*
* The user can query the database of optimised kernels in
* arm_gemm by specifying one of the enumerations of
- * arm_gemm::WeightFormat in the weight_format field of the input
+ * arm_compute::WeightFormat in the weight_format field of the input
* parameter weights_info. In case of success, the method
* writes the expected format in the output parameter
* expected_weight_format. The expected_weight_format can than be
@@ -140,7 +140,7 @@
*
* Use case one - query for a specific format:
*
- * WeightInfo weights_info(..., arm_gemm::WeightFormat::OHWIo4, ...); // Set the value of the input query.
+ * WeightInfo weights_info(..., arm_compute::WeightFormat::OHWIo4, ...); // Set the value of the input query.
* if (NEGEMMConvolutionlayer::has_opt_impl(WeightFormat(), ...., weights_info, ...))
* {
* auto conv = std::unique_ptr<NEGEMMConvolutionlayer>();
@@ -150,8 +150,8 @@
*
* Use case two - query for any format that would be optimal for the GEMM to execute:
*
- * WeightInfo weights_info(..., arm_gemm::WeightFormat::ANY, ...); // Set the value of the input query.
- * arm_gemm::WeightFormat expected_wf;
+ * WeightInfo weights_info(..., arm_compute::WeightFormat::ANY, ...); // Set the value of the input query.
+ * arm_compute::WeightFormat expected_wf;
* if (NEGEMMConvolutionlayer::has_opt_impl(expected_wf, ...., weights_info, ...))
* {
* auto conv = std::unique_ptr<NEGEMMConvolutionlayer>();
@@ -177,7 +177,7 @@
*
* @return a Status
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false);
diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp
index 1e8a2a5..45e7ff7 100644
--- a/src/core/utils/AssemblyUtils.cpp
+++ b/src/core/utils/AssemblyUtils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,5 +66,245 @@
pad_stride_info.pad_right(),
pad_stride_info.pad_bottom() };
}
+
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
+{
+    arm_gemm::WeightFormat gemm_weight_format; // 1:1 translation of the Compute Library enum value
+
+    switch(weight_format)
+    {
+        case arm_compute::WeightFormat::UNSPECIFIED:
+            gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+            break;
+        case arm_compute::WeightFormat::ANY:
+            gemm_weight_format = arm_gemm::WeightFormat::ANY;
+            break;
+        case arm_compute::WeightFormat::OHWI:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWI;
+            break;
+        case arm_compute::WeightFormat::OHWIo2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo2;
+            break;
+        case arm_compute::WeightFormat::OHWIo4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4;
+            break;
+        case arm_compute::WeightFormat::OHWIo8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8;
+            break;
+        case arm_compute::WeightFormat::OHWIo16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32;
+            break;
+        case arm_compute::WeightFormat::OHWIo64:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64;
+            break;
+        case arm_compute::WeightFormat::OHWIo128:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo128;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo2i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo2i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i8;
+            break;
+        default:
+            gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+    }
+    return gemm_weight_format;
+}
+
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
+{
+    arm_compute::WeightFormat acl_weight_format; // 1:1 translation of the assembly (arm_gemm) enum value
+
+    switch(weight_format)
+    {
+        case arm_gemm::WeightFormat::UNSPECIFIED:
+            acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+            break;
+        case arm_gemm::WeightFormat::ANY:
+            acl_weight_format = arm_compute::WeightFormat::ANY;
+            break;
+        case arm_gemm::WeightFormat::OHWI:
+            acl_weight_format = arm_compute::WeightFormat::OHWI;
+            break;
+        case arm_gemm::WeightFormat::OHWIo2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64;
+            break;
+        case arm_gemm::WeightFormat::OHWIo128:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo128;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo2i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo2i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i8;
+            break;
+        default:
+            acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+    }
+    return acl_weight_format;
+}
} // namespace assembly_utils
} // namespace arm_compute
diff --git a/src/core/utils/AssemblyUtils.h b/src/core/utils/AssemblyUtils.h
index b1aee64..7514175 100644
--- a/src/core/utils/AssemblyUtils.h
+++ b/src/core/utils/AssemblyUtils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,6 +47,22 @@
* @return Assembly padding values.
*/
arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info);
+
+/** Performs a mapping from Compute Library WeightFormat to the assembly WeightFormat enum
+ *
+ * @param[in] weight_format Compute Library WeightFormat enum value
+ *
+ * @return Assembly WeightFormat
+ */
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format);
+
+/** Performs a mapping from Assembly WeightFormat to the Compute Library WeightFormat enum
+ *
+ * @param[in] weight_format Assembly WeightFormat enum value
+ *
+ * @return Compute Library WeightFormat
+ */
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format);
} // namespace assembly
} // namespace arm_compute
#endif /* UTILS_CORE_ASSEMBLY_UTILS_H */
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 48fd7c6..4c127b4 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,57 +47,6 @@
GEMM_HYBRID_QUANTIZED
};
-/** Memory layouts for the weights tensor.
- *
- * * UNSPECIFIED is used to select kernels that do not run in
- * variable weights mode.
- *
- * * ANY is used to query the kernel database to retrieve any of the
- * kernels that runs in variable weights mode. Once a kernel is
- * found, the specific format expected by the kernel can be
- * retrieved by the user for reordering the weights tensor
- * accordingly.
- *
- * The other values OHWIo{interleave_by}i{block_by} describe the
- * memory layout of a 4D tensor with layout OHWI that has been
- * transformed into a 4D tensor with dimensions O'HWI' where:
- *
- * O' = first multiple of {interleave_by} s.t. O<=O'
- * I' = first multiple of {block_by} s.t. I<=I'
- *
- * The total size of the dst tensor is O' x H x W x I'
- *
- * The access function of the tensor with layout
- * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
- * access function, where the 6 parameters are computed as follows:
- *
- * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
- *
- * x4 = h RANGE [0, H-1] SIZE: H
- * x3 = w RANGE [0, W-1] SIZE: W
- * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
- * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
- * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
- * TOTAL SIZE: O' * H * W * I'
- *
- * 4D 6D
- * ----------------- -----------------------------------
- * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
- * + x4 * W * I' * {interleave_by}
- * + x3 * I' * {interleave_by}
- * + x2 * {interleave_by} * {block_by}
- * + x1 * {block_by}
- * + x0
- *
- * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
- * for the OHWIo{interleave_by}i{block_by} format is in reality seen
- * as a 2D tensor, where the number of rows is O'/{interleave_by}
- * and the number of columns is {interleave_by} * H * W * I'.
- *
- * The postfix *_bf16 is for the memory layout needed for the
- * fast-mode kernels, in which the weights are passed in bfloat16
- * format.
- */
enum class WeightFormat
{
UNSPECIFIED = 0x1,
@@ -138,69 +87,6 @@
OHWIo64i8 = 0x804000
};
-// OHWIo<interleave_by>i<block_by>
-inline int interleave_by(const WeightFormat wf)
-{
- return ((int)wf >> 8) & 0xFFF;
-}
-inline int block_by(const WeightFormat wf)
-{
- return ((int)wf >> 20) & 0xF;
-}
-inline bool is_fixed_format(const WeightFormat wf)
-{
- return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
-}
-
-inline std::string to_string(WeightFormat wf)
-{
-#define __CASE_WEIGHT_FORMAT(wf) \
-case WeightFormat::wf: \
- return #wf;
- switch(wf)
- {
- __CASE_WEIGHT_FORMAT(UNSPECIFIED)
- __CASE_WEIGHT_FORMAT(ANY)
- __CASE_WEIGHT_FORMAT(OHWI)
- __CASE_WEIGHT_FORMAT(OHWIo2)
- __CASE_WEIGHT_FORMAT(OHWIo4)
- __CASE_WEIGHT_FORMAT(OHWIo8)
- __CASE_WEIGHT_FORMAT(OHWIo16)
- __CASE_WEIGHT_FORMAT(OHWIo32)
- __CASE_WEIGHT_FORMAT(OHWIo64)
- __CASE_WEIGHT_FORMAT(OHWIo128)
- __CASE_WEIGHT_FORMAT(OHWIo4i2)
- __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo8i2)
- __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo16i2)
- __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo32i2)
- __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo64i2)
- __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo4i4)
- __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo8i4)
- __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo16i4)
- __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo32i4)
- __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo64i4)
- __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
- __CASE_WEIGHT_FORMAT(OHWIo2i8)
- __CASE_WEIGHT_FORMAT(OHWIo4i8)
- __CASE_WEIGHT_FORMAT(OHWIo8i8)
- __CASE_WEIGHT_FORMAT(OHWIo16i8)
- __CASE_WEIGHT_FORMAT(OHWIo32i8)
- __CASE_WEIGHT_FORMAT(OHWIo64i8)
- default:
- return "invalid value";
- }
-#undef __CASE_WEIGHT_FORMAT
-}
-
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f3fff60..f6582c7 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -368,7 +368,7 @@
return _aux_mem;
}
-Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info)
{
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index b37ab73..8d34b22 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -105,15 +105,15 @@
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat need to be passed via the
* parameter gemm_info.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const GEMMInfo &gemm_info = GEMMInfo());
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
/** Indicates if the convolution executes in variable weights mode.
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 0174d0e..f3a16f1 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -62,13 +62,13 @@
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
- const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+ const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
if(skip_im2col)
{
@@ -99,7 +99,7 @@
CpuGemmConv2d::~CpuGemmConv2d() = default;
void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
- bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_compute::WeightFormat weight_format)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
@@ -139,8 +139,8 @@
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act_info.activation()) != 0)
{
@@ -179,7 +179,7 @@
}
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
+ const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_compute::WeightFormat weight_format)
{
const DataType data_type = src->data_type();
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -203,8 +203,8 @@
PixelValue type_min{};
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -288,8 +288,8 @@
ITensorInfo *gemm_output_to_use = dst;
// Get convolved dimensions
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
src->dimension(idx_height),
kernel_width,
@@ -306,8 +306,8 @@
_skip_col2im = skip_info.skip_col2im;
// Get parameters from conv_info
- unsigned int stride_x = 0;
- unsigned int stride_y = 0;
+ unsigned int stride_x = 0;
+ unsigned int stride_y = 0;
std::tie(stride_x, stride_y) = conv_info.stride();
unsigned int mat_weights_cols = weights->dimension(idx_kernels);
@@ -360,7 +360,7 @@
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
if(!_skip_col2im && _data_layout == DataLayout::NCHW)
@@ -388,7 +388,7 @@
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
-Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
{
@@ -399,12 +399,12 @@
const unsigned int kernel_height = weights->dimension(idx_height);
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
- src->dimension(idx_height),
- kernel_width,
- kernel_height,
- conv_info,
- dilation);
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
dilation, act_info);
@@ -412,7 +412,7 @@
const bool skip_im2col = skip_info.skip_im2col;
const bool skip_col2im = skip_info.skip_col2im;
const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
@@ -464,9 +464,9 @@
dilation);
// Check if GEMM3D is supported
- const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
- dilation, act_info);
- const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
+ const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+ dilation, act_info);
+ const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -527,7 +527,7 @@
}
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
gemm_output_to_use = &info_gemm;
- const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED;
ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
weights_info.weight_format()));
@@ -558,7 +558,7 @@
{
// Run input reshaping
unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, src },
{ TensorType::ACL_DST, im2col_output.get() }
@@ -652,7 +652,7 @@
// Run weights reshaping and mark original weights tensor as unused
CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- ITensorPack pack =
+ ITensorPack pack =
{
{ TensorType::ACL_SRC, weights },
{ TensorType::ACL_DST, weights_reshaped.get() }
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index f8f0bce..08b76a6 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -123,14 +123,14 @@
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
const bool enable_fast_math = false);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
@@ -150,7 +150,7 @@
* @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -170,7 +170,7 @@
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 558ff41..c969c9f 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -164,8 +164,8 @@
{
if(!_gemm_kernel_asm)
return false;
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
- return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+ return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
private:
@@ -428,7 +428,7 @@
if(_gemm_kernel_asm->B_pretranspose_required())
{
// Fixed format kernels need no pretranspose.
- ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -492,8 +492,8 @@
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
// The 4D tensor of dimension O'HWI' created for the
@@ -507,7 +507,7 @@
const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_gemm::interleave_by(wf);
+ const int interleave_by = arm_compute::interleave_by(wf);
ldb = (interleave_by * H * W * Ip);
}
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -603,7 +603,7 @@
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -623,7 +623,7 @@
const unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
@@ -665,7 +665,7 @@
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
@@ -675,13 +675,13 @@
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
arm_gemm::GemmConfig cfg;
- cfg.weight_format = info.weight_format;
-
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -689,12 +689,12 @@
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -702,12 +702,12 @@
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -722,7 +722,7 @@
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -730,6 +730,7 @@
ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel");
break;
}
+ expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf);
return Status{};
}
@@ -762,9 +763,9 @@
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- arm_gemm::WeightFormat expected_weight_format;
- const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ arm_compute::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 4ef108d..691eeff 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -41,19 +41,19 @@
struct AsmGemmInfo
{
- AsmConvMethod method{ AsmConvMethod::Im2Col };
- PadStrideInfo ps_info{};
- ActivationLayerInfo activation_info{};
- GEMMLowpOutputStageInfo output_stage{};
- bool negated_offsets{ true };
- bool reinterpret_input_as_3d{ false };
- bool depth_output_gemm3d{ false };
- int64_t padding_top{ 0 };
- int64_t padding_left{ 0 };
- float padding_value{ 0.f };
- bool fast_mode{ false };
- bool fixed_format{ false };
- arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ AsmConvMethod method{ AsmConvMethod::Im2Col };
+ PadStrideInfo ps_info{};
+ ActivationLayerInfo activation_info{};
+ GEMMLowpOutputStageInfo output_stage{};
+ bool negated_offsets{ true };
+ bool reinterpret_input_as_3d{ false };
+ bool depth_output_gemm3d{ false };
+ int64_t padding_top{ 0 };
+ int64_t padding_left{ 0 };
+ float padding_value{ 0.f };
+ bool fast_mode{ false };
+ bool fixed_format{ false };
+ arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -70,12 +70,12 @@
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual experimental::MemoryRequirements workspace() const = 0;
- virtual bool is_configured() const = 0;
- virtual bool isVarWeightsKernel() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -105,12 +105,12 @@
*
* This method has the same use of @ref
* NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
- * the value of arm_gemm::WeightFormat need to be passed via the
+ * the value of arm_compute::WeightFormat needs to be passed via the
* parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -133,8 +133,8 @@
}
// Inherited methods overridden:
- void prepare(ITensorPack &tensors) override;
- void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private:
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 13635c6..fe3ea6a 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -80,7 +80,7 @@
return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
}
-Status NEGEMMConvolutionLayer::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+Status NEGEMMConvolutionLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
{
diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h
index 5f46277..7adfa8f 100644
--- a/tests/framework/Asserts.h
+++ b/tests/framework/Asserts.h
@@ -30,6 +30,8 @@
#include <sstream>
#include <type_traits>
+#include "utils/TypePrinter.h"
+
namespace arm_compute
{
namespace test
@@ -42,9 +44,9 @@
return value;
}
-inline std::string make_printable(arm_gemm::WeightFormat wf)
+inline std::string make_printable(const arm_compute::WeightFormat wf)
{
- return arm_gemm::to_string(wf);
+ return arm_compute::to_string(wf);
}
inline unsigned int make_printable(uint8_t value)
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 940983f..0194220 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -511,13 +511,13 @@
FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
@@ -527,18 +527,18 @@
FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
}
// UC3_1_* tests: the user queries for ANY fixed format, but there is
@@ -548,14 +548,14 @@
FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::S32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::S32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
}
@@ -572,24 +572,24 @@
FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
}
FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
- framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+ framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
}
namespace
{
-using TestCaseType = std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat>;
+using TestCaseType = std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat>;
auto prepare_weights_shapes = framework::dataset::make("TensorShape",
{
// OHWIo<interleave_by>i<block_by>
@@ -601,51 +601,51 @@
//
// Change N for OHWIo4
- TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_compute::WeightFormat::OHWIo4 }),
// // Change N for OHWIo8
- TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_compute::WeightFormat::OHWIo8 }),
// // Change N for OHWIo4 when H, W and C are not 1
- TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_compute::WeightFormat::OHWIo4 }),
// // Fix N and move HWI around, with different data layouts and formats
- TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
- TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_gemm::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_compute::WeightFormat::OHWIo8 }),
+ TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_compute::WeightFormat::OHWIo8 }),
// // Adding <block_by> on I (=C)
- TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
- TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
- TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
+ TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }),
// ---------
- TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
- TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_compute::WeightFormat::OHWIo4 }),
+ TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }),
});
} // unnamed namespace
@@ -653,14 +653,14 @@
DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
prepare_weights_shapes, shapes)
{
- const TensorShape input_shape = std::get<0>(shapes);
- const TensorShape expected_shape = std::get<1>(shapes);
- const arm_gemm::WeightFormat wf = std::get<2>(shapes);
- const DataType DT = DataType::F32;
- const DataLayout DL = DataLayout::NHWC;
- const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
- const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
- const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ const TensorShape input_shape = std::get<0>(shapes);
+ const TensorShape expected_shape = std::get<1>(shapes);
+ const arm_compute::WeightFormat wf = std::get<2>(shapes);
+ const DataType DT = DataType::F32;
+ const DataLayout DL = DataLayout::NHWC;
+ const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
+ const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf);
+ const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
}
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index d3804ee..c58a0a2 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -122,14 +122,14 @@
{
case DataType::QASYMM8:
{
- std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
}
case DataType::QASYMM8_SIGNED:
{
- std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
@@ -400,7 +400,7 @@
};
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
-inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format)
+inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_compute::WeightFormat weight_format)
{
const DataLayout data_layout = tensor_info.data_layout();
ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
@@ -411,8 +411,8 @@
const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
- const int interleave_by = arm_gemm::interleave_by(weight_format);
- const int block_by = arm_gemm::block_by(weight_format);
+ const int interleave_by = arm_compute::interleave_by(weight_format);
+ const int block_by = arm_compute::block_by(weight_format);
const int Ip = arm_gemm::roundup<unsigned int>(C, block_by); // C'=I'
const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
@@ -421,12 +421,12 @@
}
template <typename ScalarType, typename AccessorType>
-inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format)
+inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_compute::WeightFormat weight_format)
{
- ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
// Data Layout: OHWIo<interleave_by>i<block_by>
- const int interleave_by = arm_gemm::interleave_by(weight_format);
- const int block_by = arm_gemm::block_by(weight_format);
+ const int interleave_by = arm_compute::interleave_by(weight_format);
+ const int block_by = arm_compute::block_by(weight_format);
const TensorShape src_tensor_shape = src.shape();
const DataLayout data_layout = src.data_layout();
ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
@@ -545,12 +545,12 @@
const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)];
const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)];
- const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
+ const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_compute::WeightFormat::ANY);
const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info,
&bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info));
// Make surethat the setup founds a fixed-format kernel as requested by the test case.
ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format);
configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info,
@@ -576,7 +576,7 @@
protected:
std::unique_ptr<ConvolutionFunction> conv{};
- arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
TensorClass _target{};
SimpleTensor<ScalarType> _reference{};
};
@@ -669,7 +669,7 @@
{
public:
template <typename...>
- void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format)
+ void setup(DataType data_type, arm_compute::WeightFormat query_weight_format)
{
auto conv = std::make_unique<ConvolutionClass>();
const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC);
@@ -683,8 +683,8 @@
}
protected:
- bool _kernel_found{ false };
- arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ bool _kernel_found{ false };
+ arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
};
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 23e73f6..f47943a 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -473,14 +473,14 @@
}
#if defined(ARM_COMPUTE_ENABLE_BF16)
-inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16& v)
+inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16 &v)
{
std::stringstream str;
str << v;
os << str.str();
return os;
}
-#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
/** Formatted output of the BoundingBoxTransformInfo type.
*
@@ -3252,19 +3252,81 @@
return str.str();
}
-inline ::std::ostream &operator<<(::std::ostream &os, const arm_gemm::WeightFormat &wf)
+/** Formatted output of the arm_compute::WeightFormat type.
+ *
+ * @param[in] wf arm_compute::WeightFormat Type to output.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const WeightFormat wf)
{
- os << arm_gemm::to_string(wf);
- return os;
-}
-inline std::string to_string(const arm_gemm::WeightFormat wf)
-{
- std::stringstream str;
- str << wf;
- return str.str();
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf: \
+ return #wf;
+ switch(wf)
+ {
+ __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+ __CASE_WEIGHT_FORMAT(ANY)
+ __CASE_WEIGHT_FORMAT(OHWI)
+ __CASE_WEIGHT_FORMAT(OHWIo2)
+ __CASE_WEIGHT_FORMAT(OHWIo4)
+ __CASE_WEIGHT_FORMAT(OHWIo8)
+ __CASE_WEIGHT_FORMAT(OHWIo16)
+ __CASE_WEIGHT_FORMAT(OHWIo32)
+ __CASE_WEIGHT_FORMAT(OHWIo64)
+ __CASE_WEIGHT_FORMAT(OHWIo128)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo2i8)
+ __CASE_WEIGHT_FORMAT(OHWIo4i8)
+ __CASE_WEIGHT_FORMAT(OHWIo8i8)
+ __CASE_WEIGHT_FORMAT(OHWIo16i8)
+ __CASE_WEIGHT_FORMAT(OHWIo32i8)
+ __CASE_WEIGHT_FORMAT(OHWIo64i8)
+ default:
+ return "invalid value";
+ }
+#undef __CASE_WEIGHT_FORMAT
}
-inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat> values)
+/** Formatted output of the arm_compute::WeightFormat type.
+ *
+ * @param[out] os Output stream.
+ * @param[in] wf WeightFormat to output.
+ *
+ * @return Modified output stream.
+ */
+inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::WeightFormat &wf)
+{
+ os << to_string(wf);
+ return os;
+}
+
+/** Formatted output of the std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> tuple.
+ *
+ * @param[in] values tuple of input and output tensor shapes and WeightFormat used.
+ *
+ * @return Formatted string.
+ */
+inline std::string to_string(const std::tuple<TensorShape, TensorShape, arm_compute::WeightFormat> values)
{
std::stringstream str;
str << "[Input shape = " << std::get<0>(values);