COMPMID-1017: Implement dilated convolution in NEON, OpenCL, and GC
Change-Id: If4626ec9e215e14dffe22e80812da5bac84a52e2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125734
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index f4b4553..4a237f9 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -292,7 +292,8 @@
const std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(unsigned int width, unsigned int height,
unsigned int kernel_width, unsigned int kernel_height,
- const PadStrideInfo &pad_stride_info)
+ const PadStrideInfo &pad_stride_info,
+ const Size2D &dilation)
{
const unsigned int pad_left = pad_stride_info.pad_left();
const unsigned int pad_top = pad_stride_info.pad_top();
@@ -305,12 +306,12 @@
switch(pad_stride_info.round())
{
case DimensionRoundingType::FLOOR:
- w = static_cast<unsigned int>(std::floor((static_cast<float>(width + pad_left + pad_right - kernel_width) / stride_x) + 1));
- h = static_cast<unsigned int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - kernel_height) / stride_y) + 1));
+ w = static_cast<unsigned int>(std::floor((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1));
+ h = static_cast<unsigned int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1));
break;
case DimensionRoundingType::CEIL:
- w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + pad_left + pad_right - kernel_width) / stride_x) + 1));
- h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - kernel_height) / stride_y) + 1));
+ w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1));
+ h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1));
break;
default:
ARM_COMPUTE_ERROR("Unsupported rounding type");