diff --git a/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp b/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp
index cda6e96..9e4010e 100644
--- a/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp
@@ -58,11 +58,16 @@
     // Validate output if initialized
     if(output->total_size() != 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[0] < padding_left.x() + padding_right.y());
-        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[0] / block_shape_x != (output->tensor_shape()[0] - padding_left.x() - padding_right.y()));
-        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[1] / block_shape_y != (output->tensor_shape()[1] - padding_left.x() - padding_right.y()));
-        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[2] != output->tensor_shape()[2]);
-        ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[3] % (block_shape_x * block_shape_y) != 0);
+        const DataLayout data_layout = input->data_layout(); // Shape checks below index by the input's layout instead of fixed positions 0..3
+        const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+        const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+        const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+        const int        idx_batch   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+        ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[idx_width] < padding_left.x() + padding_right.y()); // Output width must not be smaller than the summed padding terms
+        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_width] / block_shape_x != (output->tensor_shape()[idx_width] - padding_left.x() - padding_right.y()));
+        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_height] / block_shape_y != (output->tensor_shape()[idx_height] - padding_left.x() - padding_right.y())); // NOTE(review): height check reuses padding_left.x()/padding_right.y() exactly like the width check - confirm this is intended
+        ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[idx_channel] != output->tensor_shape()[idx_channel]); // Channel extent must match between input and output
+        ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape()[idx_batch] % (block_shape_x * block_shape_y) != 0); // Output batch extent must be a multiple of block_shape_x * block_shape_y
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
     }
 
@@ -85,13 +90,18 @@
     _paddings    = paddings;
     _output      = output;
 
+    const DataLayout data_layout = input->info()->data_layout(); // Queried once: drives both the dimension indices and the kernel name suffix
+    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int        idx_batch   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
     // Create kernel
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
-    build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(0)));
-    build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(1)));
-    build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(3)));
-    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch", build_opts.options()));
+    build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_width))); // Output extents are looked up through the layout-derived indices
+    build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_height)));
+    build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(idx_batch)));
+    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_" + lower_string(string_from_data_layout(data_layout)), build_opts.options())); // Kernel name is "space_to_batch_" plus the lower-cased layout string
 
     // Configure kernel window
     Window win = calculate_max_window(*output->info(), Steps());
@@ -111,19 +121,24 @@
     _input  = input;
     _output = output;
 
+    const DataLayout data_layout = input->info()->data_layout(); // Queried once: drives both the dimension indices and the kernel name suffix
+    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int        idx_batch   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
     // Create kernel
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
-    build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(0)));
-    build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(1)));
-    build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(3)));
+    build_opts.add_option("-DWIDTH_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_width))); // Output extents are looked up through the layout-derived indices
+    build_opts.add_option("-DHEIGHT_OUT=" + support::cpp11::to_string(output->info()->dimension(idx_height)));
+    build_opts.add_option("-DBATCH_SIZE=" + support::cpp11::to_string(output->info()->dimension(idx_batch)));
     build_opts.add_option("-DBLOCK_SHAPE_X=" + support::cpp11::to_string(block_shape_x));
     build_opts.add_option("-DBLOCK_SHAPE_Y=" + support::cpp11::to_string(block_shape_y));
     build_opts.add_option("-DPAD_START_X=" + support::cpp11::to_string(padding_left.x()));
     build_opts.add_option("-DPAD_END_X=" + support::cpp11::to_string(padding_right.x()));
     build_opts.add_option("-DPAD_START_Y=" + support::cpp11::to_string(padding_left.y()));
     build_opts.add_option("-DPAD_END_Y=" + support::cpp11::to_string(padding_right.y()));
-    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_static", build_opts.options()));
+    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("space_to_batch_static_" + lower_string(string_from_data_layout(data_layout)), build_opts.options())); // Kernel name is "space_to_batch_static_" plus the lower-cased layout string
 
     // Configure kernel window
     Window win = calculate_max_window(*output->info(), Steps());
