Refine DirectConv3d support
- Decouple data support of CpuDirectConv3dKernel
- Update documentation for Conv3d
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I1d94aa28f821f45a1a3d39cc3335c8faeee89f0d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6453
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp
index 3827910..aa74e42 100644
--- a/src/cpu/operators/CpuDirectConv3d.cpp
+++ b/src/cpu/operators/CpuDirectConv3d.cpp
@@ -40,10 +40,10 @@
{
}
-void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info)
+void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info);
- ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
+ ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info);
+ ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
_conv_kernel = std::make_unique<kernels::CpuDirectConv3dKernel>();
@@ -55,7 +55,7 @@
_dim_split = Window::DimY;
- _conv_kernel->configure(src, weights, biases, dst, conv_info);
+ _conv_kernel->configure(src0, src1, src2, dst, conv_info);
//Configure Activation Layer
_is_activationlayer_enabled = conv_info.act_info.enabled();
@@ -66,16 +66,12 @@
}
}
-Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info)
+Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-
- // output might not be initialized since it can be an intermediate tensor of another layer
- DataType data_type = src->data_type();
- TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
// Validate Convolution kernel
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info));
if(conv_info.act_info.enabled())
{
diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h
index ad04dee..f7c3099 100644
--- a/src/cpu/operators/CpuDirectConv3d.h
+++ b/src/cpu/operators/CpuDirectConv3d.h
@@ -57,23 +57,31 @@
~CpuDirectConv3d();
/** Set the input, weights, biases and output tensor info.
*
- * @param[in, out] src Input tensor info.
- * @param[in] weights Set of kernels to convolve the input volume.
- * The 2nd dimension must be the same as the input's volume 1st dimension.
- * Data type supported: Same as @p src.
- * @param[in] biases Set of biases. Can be nullptr. Data type supported: Same as @p src.
+ * Valid data layouts:
+ * - NDHWC
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |dst |
+ * |:--------------|:------------------|:------|:--------------|
+ * |F16 |F16 |F16 |F16 |
+ * |F32 |F32 |F32 |F32 |
+ *
+ * @param[in, out] src0 Input tensor info.
+ * @param[in] src1 Set of kernels to convolve the input volume.
+ * The 2nd dimension must be the same as the 1st dimension of the @p src0 volume.
+ * @param[in] src2 Set of biases. Can be nullptr.
* @param[out] dst Output tensor info.
* The 1st dimension must be equal to the 1st dimension of the @p src1 tensor.
* @param[in] conv_info Contains padding, stride, activation information.
*/
- void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info);
+ void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuDirectConv3d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info);
// Inherited methods overridden:
void run(ITensorPack &tensors) override;