COMPMID-3280: Make all ML primitives for CL use the new interface - Part 1
- Only the CL kernels (CLKernels) have been updated in this part
Change-Id: Ife55b847c2e39e712a186eb6ca452503d5b66937
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3001
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLFuseBatchNormalizationKernel.h b/arm_compute/core/CL/kernels/CLFuseBatchNormalizationKernel.h
index aa60376..2d62a57 100644
--- a/arm_compute/core/CL/kernels/CLFuseBatchNormalizationKernel.h
+++ b/arm_compute/core/CL/kernels/CLFuseBatchNormalizationKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,6 +65,25 @@
void configure(const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, ICLTensor *fused_weights, ICLTensor *fused_bias,
const ICLTensor *input_bias = nullptr, const ICLTensor *bn_beta = nullptr, const ICLTensor *bn_gamma = nullptr,
float epsilon = 0.001f, FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION);
+ /** Set the input and output tensors of the kernel, using the given compile context
+ *
+ * @param[in] compile_context The compile context to be used.
+ * @param[in] input_weights Input weights tensor for convolution or depthwise convolution layer. Data type supported: F16/F32. Data layout supported: NCHW, NHWC
+ * @param[in] bn_mean Batch normalization layer mean tensor. Same as @p input_weights
+ * @param[in] bn_var Batch normalization layer variance tensor. Same as @p input_weights
+ * @param[out] fused_weights Output fused weights tensor. It can be a nullptr in case of in-place computation. Same as @p input_weights
+ * @param[out] fused_bias Output fused bias tensor. It can be a nullptr in case of in-place computation and input_bias != nullptr. Same as @p input_weights
+ * @param[in] input_bias (Optional) Input bias tensor for convolution or depthwise convolution layer. It can be a nullptr in case the bias tensor is not required. Same as @p input_weights
+ * @param[in] bn_beta (Optional) Batch normalization layer beta tensor. It can be a nullptr in case the beta tensor is not required. Same as @p input_weights
+ * @note If nullptr, bn_beta is set to 0.0
+ * @param[in] bn_gamma (Optional) Batch normalization layer gamma tensor. It can be a nullptr in case the gamma tensor is not required. Same as @p input_weights
+ * @note If nullptr, bn_gamma is set to 1.0
+ * @param[in] epsilon (Optional) Batch normalization layer epsilon parameter. Defaults to 0.001f.
+ * @param[in] fbn_type (Optional) Fused batch normalization type. Defaults to CONVOLUTION.
+ */
+ void configure(CLCompileContext &compile_context, const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, ICLTensor *fused_weights, ICLTensor *fused_bias,
+ const ICLTensor *input_bias = nullptr, const ICLTensor *bn_beta = nullptr, const ICLTensor *bn_gamma = nullptr,
+ float epsilon = 0.001f, FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION);
/** Static function to check if given info will lead to a valid configuration of @ref CLFuseBatchNormalizationKernel
*
* @param[in] input_weights Input weights tensor info for convolution or depthwise convolution layer. Data type supported: F16/F32. Data layout supported: NCHW, NHWC