COMPIMID-553: MobileNet use case.

Change-Id: I1181abbd5785065f3d57e91844376a4b110938a9
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110701
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/graph/Nodes.h b/arm_compute/graph/Nodes.h
index 79407f9..0282e1d 100644
--- a/arm_compute/graph/Nodes.h
+++ b/arm_compute/graph/Nodes.h
@@ -29,6 +29,7 @@
 #include "arm_compute/graph/nodes/BranchLayer.h"
 #include "arm_compute/graph/nodes/ConvolutionLayer.h"
 #include "arm_compute/graph/nodes/DepthConvertLayer.h"
+#include "arm_compute/graph/nodes/DepthwiseConvolutionLayer.h"
 #include "arm_compute/graph/nodes/DequantizationLayer.h"
 #include "arm_compute/graph/nodes/FlattenLayer.h"
 #include "arm_compute/graph/nodes/FloorLayer.h"
diff --git a/arm_compute/graph/nodes/BranchLayer.h b/arm_compute/graph/nodes/BranchLayer.h
index 3d13f5f..dd05315 100644
--- a/arm_compute/graph/nodes/BranchLayer.h
+++ b/arm_compute/graph/nodes/BranchLayer.h
@@ -64,6 +64,17 @@
         },
         std::move(rest_sub_graphs)...);
     }
+    /** Constructor for a branch made of a single sub-graph; the merge method is set to DEPTH_CONCATENATE
+     *
+     * @param[in] sub_graph Sub graph, moved into this layer's list of sub-graphs
+     */
+    template <typename... Ts> // NOTE(review): Ts does not appear in the parameter list of this overload - confirm the template is needed
+    BranchLayer(SubGraph &&sub_graph)
+        : _branch_merge_method(BranchMergeMethod::DEPTH_CONCATENATE), _sub_graphs()
+    {
+        /* TODO:(geopin01) Use traits to make sure variadic arguments are of SubGraph type */
+        _sub_graphs.push_back(arm_compute::support::cpp14::make_unique<SubGraph>(std::move(sub_graph)));
+    }
 
     // Inherited methods overriden:
     std::unique_ptr<arm_compute::IFunction> instantiate_node(GraphContext &ctx, ITensorObject *input, ITensorObject *output) override;
diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayer.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayer.h
index 48b2ef9..8b7e3b8 100644
--- a/arm_compute/graph/nodes/DepthwiseConvolutionLayer.h
+++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayer.h
@@ -47,13 +47,13 @@
      * @param[in] conv_width  Convolution width
      * @param[in] conv_height Convolution height
      * @param[in] weights     Weights values tensor
+     * @param[in] biases      Biases values tensor
      * @param[in] conv_info   Convolution info
-     * @param[in] biases      (Optional) Biases values tensor
      * @param[in] opt3x3      (Optional) If true executes DepthwiseConvolutionLayer3x3
      */
     template <typename AccessorType>
-    DepthwiseConvolutionLayer(unsigned int conv_width, unsigned int conv_height, AccessorType &&weights, const PadStrideInfo conv_info, AccessorType &&biases = nullptr, bool opt3x3 = true)
-        : _conv_width(conv_width), _conv_height(conv_height), _weights(std::move(weights)), _conv_info(conv_info), _biases(std::move(biases)), _opt3x3(opt3x3)
+    DepthwiseConvolutionLayer(unsigned int conv_width, unsigned int conv_height, AccessorType &&weights, AccessorType &&biases, const PadStrideInfo conv_info, bool opt3x3 = true)
+        : _conv_width(conv_width), _conv_height(conv_height), _weights(std::move(weights)), _biases(std::move(biases)), _conv_info(conv_info), _opt3x3(opt3x3)
     {
     }
 
@@ -64,8 +64,8 @@
     unsigned int        _conv_width;
     unsigned int        _conv_height;
     Tensor              _weights;
-    const PadStrideInfo _conv_info;
     Tensor              _biases;
+    const PadStrideInfo _conv_info;
     bool                _opt3x3;
 };
 } // namespace graph
diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp
new file mode 100644
index 0000000..2b2da9e
--- /dev/null
+++ b/examples/graph_mobilenet.cpp
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/Nodes.h"
+#include "support/ToolchainSupport.h"
+#include "utils/GraphUtils.h"
+#include "utils/Utils.h"
+
+#include <cstdlib>
+
+using namespace arm_compute::graph;
+using namespace arm_compute::graph_utils;
+/** Builds a BranchLayer wrapping one sub-graph: 3x3 depthwise convolution, batch normalization, bounded ReLU, convolution, batch normalization, bounded ReLU */
+BranchLayer get_dwsc_node(const std::string &data_path, std::string &&param_path,
+                          unsigned int  conv_filt,
+                          PadStrideInfo dwc_pad_stride_info, PadStrideInfo conv_pad_stride_info)
+{
+    std::string total_path = "/cnn_data/mobilenet_v1_model/" + param_path + "_"; // Common prefix of every parameter file name used by this node
+    SubGraph    sg;
+    sg << DepthwiseConvolutionLayer(
+           3U, 3U, // conv_width, conv_height
+           get_weights_accessor(data_path, total_path + "depthwise_depthwise_weights.npy"),
+           std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Biases slot: null accessor, no biases file is referenced for this layer
+           dwc_pad_stride_info,
+           true) // opt3x3
+       << BatchNormalizationLayer(
+           get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_moving_mean.npy"),
+           get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_moving_variance.npy"),
+           get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_beta.npy"),
+           get_weights_accessor(data_path, total_path + "depthwise_BatchNorm_gamma.npy"),
+           0.001f) // presumably epsilon - confirm against BatchNormalizationLayer
+       << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f))
+       << ConvolutionLayer(
+           1U, 1U, conv_filt, // presumably kernel width, kernel height, number of output feature maps - confirm against ConvolutionLayer
+           get_weights_accessor(data_path, total_path + "pointwise_weights.npy"),
+           std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor again: no biases file is referenced for this layer
+           conv_pad_stride_info)
+       << BatchNormalizationLayer(
+           get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_moving_mean.npy"),
+           get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_moving_variance.npy"),
+           get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_beta.npy"),
+           get_weights_accessor(data_path, total_path + "pointwise_BatchNorm_gamma.npy"),
+           0.001f)
+       << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
+
+    return BranchLayer(std::move(sg)); // The sub-graph is moved into the returned BranchLayer
+}
+
+/** Example demonstrating how to implement MobileNet's network using the Compute Library's graph API
+ *
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments ( [optional] Path to the weights folder, [optional] image, [optional] labels )
+ */
+void main_graph_mobilenet(int argc, const char **argv)
+{
+    std::string data_path; /* Path to the trainable data */
+    std::string image;     /* Image data */
+    std::string label;     /* Label data */
+
+    constexpr float mean_r = 122.68f; /* Mean value to subtract from red channel */
+    constexpr float mean_g = 116.67f; /* Mean value to subtract from green channel */
+    constexpr float mean_b = 104.01f; /* Mean value to subtract from blue channel */
+
+    // Parse arguments
+    if(argc < 2)
+    {
+        // Print help
+        std::cout << "Usage: " << argv[0] << " [path_to_data] [image] [labels]\n\n";
+        std::cout << "No data folder provided: using random values\n\n";
+    }
+    else if(argc == 2)
+    {
+        data_path = argv[1];
+        std::cout << "Usage: " << argv[0] << " " << argv[1] << " [image] [labels]\n\n";
+        std::cout << "No image provided: using random values\n\n";
+    }
+    else if(argc == 3)
+    {
+        data_path = argv[1];
+        image     = argv[2];
+        std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " [labels]\n\n";
+        std::cout << "No text file with labels provided: skipping output accessor\n\n";
+    }
+    else // argc >= 4: data path, image and labels are all taken from the command line
+    {
+        data_path = argv[1];
+        image     = argv[2];
+        label     = argv[3];
+    }
+
+    // Check if OpenCL is available and initialize the scheduler
+    TargetHint hint = TargetHint::NEON; // Default target; replaced by OPENCL below when Graph reports it as available
+    if(Graph::opencl_is_available())
+    {
+        hint = TargetHint::OPENCL;
+    }
+
+    Graph graph;
+    graph << hint
+          << Tensor(TensorInfo(TensorShape(224U, 224U, 3U, 1U), 1, DataType::F32), // Input tensor: shape (224, 224, 3, 1), F32
+                    get_input_accessor(image, mean_r, mean_g, mean_b))
+          << ConvolutionLayer(
+              3U, 3U, 32U,
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Conv2d_0_weights.npy"),
+              std::unique_ptr<arm_compute::graph::ITensorAccessor>(nullptr), // Null accessor in the biases slot: no biases file is referenced for this layer
+              PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR))
+          << BatchNormalizationLayer(
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Conv2d_0_BatchNorm_moving_mean.npy"),
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Conv2d_0_BatchNorm_moving_variance.npy"),
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Conv2d_0_BatchNorm_beta.npy"),
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Conv2d_0_BatchNorm_gamma.npy"),
+              0.001f)
+          << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f))
+          << get_dwsc_node(data_path, "Conv2d_1", 64, PadStrideInfo(1, 1, 1, 1), PadStrideInfo(1, 1, 0, 0)) // First of 13 nodes (Conv2d_1 .. Conv2d_13) built by get_dwsc_node
+          << get_dwsc_node(data_path, "Conv2d_2", 128, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_3", 128, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_4", 256, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_5", 256, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_6", 512, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_7", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_8", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_9", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_10", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_11", 512, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_12", 1024, PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << get_dwsc_node(data_path, "Conv2d_13", 1024, PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR), PadStrideInfo(1, 1, 0, 0))
+          << PoolingLayer(PoolingLayerInfo(PoolingType::AVG)) // Average pooling
+          << ConvolutionLayer(
+              1U, 1U, 1001U,
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Logits_Conv2d_1c_1x1_weights.npy"),
+              get_weights_accessor(data_path, "/cnn_data/mobilenet_v1_model/Logits_Conv2d_1c_1x1_biases.npy"), // Unlike the other convolutions in this graph, a biases accessor is supplied here
+              PadStrideInfo(1, 1, 0, 0))
+          << ReshapeLayer(TensorShape(1001U)) // Reshape to a one-dimensional tensor of 1001 elements
+          << SoftmaxLayer()
+          << Tensor(get_output_accessor(label, 5)); // presumably reports the top 5 predictions - confirm against get_output_accessor
+
+    graph.run();
+}
+
+/** Main program for MobileNetV1
+ *
+ * @param[in] argc Number of arguments
+ * @param[in] argv Arguments ( [optional] Path to the weights folder, [optional] image, [optional] labels )
+ */
+int main(int argc, const char **argv)
+{
+    return arm_compute::utils::run_example(argc, argv, main_graph_mobilenet); // The process exit code is whatever run_example returns
+}
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 38e367d..e8882b9 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -130,7 +130,8 @@
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts));
 
     // Set kernel static arguments
-    unsigned int idx = 2 * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
+    unsigned int include_output = (output != nullptr) ? 1 : 0;
+    unsigned int idx            = (1 + include_output) * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
     _kernel.setArg<cl_float>(idx++, _epsilon);
 
     // Configure kernel window
@@ -160,7 +161,8 @@
     Window vector_slice = window.first_slice_window_1D();
     vector_slice.set(Window::DimX, Window::Dimension(0, 0, 0));
 
-    unsigned int idx = 2 * num_arguments_per_3D_tensor();
+    unsigned int include_output = (_output != nullptr) ? 1 : 0;
+    unsigned int idx            = (1 + include_output) * num_arguments_per_3D_tensor();
     add_1D_tensor_argument(idx, _mean, vector_slice);
     add_1D_tensor_argument(idx, _var, vector_slice);
     add_1D_tensor_argument(idx, _beta, vector_slice);
diff --git a/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp
index be8fae2..e86c55f 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolution3x3Kernel.cpp
@@ -37,6 +37,29 @@
 
 using namespace arm_compute;
 
+namespace
+{
+/** Calculates expected output shape dimension
+ * @param[in] input_shape   Input shape
+ * @param[in] weights_shape Weights shape; its x and y dimensions are used as the kernel size
+ * @param[in] conv_info     Convolution info forwarded to scaled_dimensions
+ * @return Expected output shape
+ */
+TensorShape get_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info)
+{
+    unsigned int output_width  = 0;
+    unsigned int output_height = 0;
+
+    std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), conv_info);
+
+    TensorShape output_shape = input_shape; // Start from the input shape: only dimensions 0 and 1 are overwritten below
+    output_shape.set(0, output_width);
+    output_shape.set(1, output_height);
+
+    return output_shape;
+}
+} // namespace
+
 CLDepthwiseConvolution3x3Kernel::CLDepthwiseConvolution3x3Kernel()
     : _border_size(0), _input(), _output(), _weights(), _biases(), _conv_stride_x(0), _conv_stride_y(0), _conv_pad_left(0), _conv_pad_top(0)
 {
@@ -50,9 +73,7 @@
 void CLDepthwiseConvolution3x3Kernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, weights);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != 3 || weights->info()->dimension(1) != 3);
 
     if(biases != nullptr)
@@ -69,13 +90,18 @@
         ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1);
     }
 
-    std::pair<unsigned int, unsigned int> expected_output = scaled_dimensions(input->info()->tensor_shape().x(), input->info()->tensor_shape().y(),
-                                                                              weights->info()->tensor_shape().x(), weights->info()->tensor_shape().y(),
-                                                                              conv_info);
+    // Get convolved dimensions
+    TensorShape output_shape = get_output_shape(input->info()->tensor_shape(), weights->info()->tensor_shape(), conv_info);
 
-    ARM_COMPUTE_UNUSED(expected_output);
-    ARM_COMPUTE_ERROR_ON(expected_output.first != output->info()->tensor_shape().x());
-    ARM_COMPUTE_ERROR_ON(expected_output.second != output->info()->tensor_shape().y());
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*output->info(),
+                       output_shape,
+                       1,
+                       input->info()->data_type(),
+                       input->info()->fixed_point_position(),
+                       input->info()->quantization_info());
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
 
     _input         = input;
     _output        = output;
diff --git a/src/graph/nodes/ConvolutionLayer.cpp b/src/graph/nodes/ConvolutionLayer.cpp
index a7236fc..ae4a8d7 100644
--- a/src/graph/nodes/ConvolutionLayer.cpp
+++ b/src/graph/nodes/ConvolutionLayer.cpp
@@ -189,7 +189,7 @@
                                      in->info()->data_type(),
                                      in->info()->fixed_point_position()));
     }
-    if(_biases.tensor() == nullptr)
+    if(_biases.has_accessor() && _biases.tensor() == nullptr)
     {
         _biases.set_info(TensorInfo(TensorShape(_ofm), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position()));
     }
@@ -200,11 +200,14 @@
 
     // Check if the weights and biases are loaded
     bool weights_are_loaded = _weights.tensor() != nullptr;
-    bool biases_are_loaded  = _weights.tensor() != nullptr;
+    bool biases_are_loaded  = _biases.has_accessor() ? _biases.tensor() != nullptr : true;
 
     // Set bias and weights target
     _weights.set_target(_target_hint);
-    _biases.set_target(_target_hint);
+    if(_biases.has_accessor())
+    {
+        _biases.set_target(_target_hint);
+    }
 
     // Calculate output shape
     TensorShape output_shape = calculate_convolution_layer_output_shape(in->info()->tensor_shape(), _weights.info().tensor_shape(), _conv_info);
diff --git a/src/graph/nodes/DepthwiseConvolutionLayer.cpp b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
index 1c006d6..ceac2a2 100644
--- a/src/graph/nodes/DepthwiseConvolutionLayer.cpp
+++ b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
@@ -51,10 +51,13 @@
     }
 
     bool weights_is_loaded = _weights.tensor() != nullptr;
-    bool biases_is_loaded  = _biases.has_accessor() ? _biases.tensor() != nullptr : false;
+    bool biases_is_loaded  = _biases.has_accessor() ? _biases.tensor() != nullptr : true;
 
     _weights.set_target(_target_hint);
-    _biases.set_target(_target_hint);
+    if(_biases.has_accessor())
+    {
+        _biases.set_target(_target_hint);
+    }
 
     // Create node context
     NodeContext node_ctx(OperationType::DepthwiseConvolutionLayer);
diff --git a/src/graph/nodes/ReshapeLayer.cpp b/src/graph/nodes/ReshapeLayer.cpp
index 4967534..bbe0739 100644
--- a/src/graph/nodes/ReshapeLayer.cpp
+++ b/src/graph/nodes/ReshapeLayer.cpp
@@ -47,11 +47,11 @@
     arm_compute::auto_init_if_empty(*out->info(), _shape, 1, in->info()->data_type(), in->info()->fixed_point_position());
 
     // Create node context
-    NodeContext node_ctx(OperationType::QuantizationLayer);
+    NodeContext node_ctx(OperationType::ReshapeLayer);
     node_ctx.set_target(_target_hint);
     node_ctx.add_input(in);
     node_ctx.add_output(out);
 
     // Get function
-    return OperationRegistry::get().find_operation(OperationType::QuantizationLayer, _target_hint)->configure(node_ctx);
+    return OperationRegistry::get().find_operation(OperationType::ReshapeLayer, _target_hint)->configure(node_ctx);
 }
diff --git a/src/graph/operations/CLSimpleOperations.cpp b/src/graph/operations/CLSimpleOperations.cpp
index 881f491..647f88f 100644
--- a/src/graph/operations/CLSimpleOperations.cpp
+++ b/src/graph/operations/CLSimpleOperations.cpp
@@ -138,7 +138,7 @@
 /* DepthwiseConvolutionLayer Layer */
 REGISTER_SIMPLE_OPERATION(CLDepthwiseConvolutionOperation, OPENCL, OperationType::DepthwiseConvolutionLayer)
 {
-    ARM_COMPUTE_ERROR_ON(ctx.num_inputs() != 2 || ctx.num_inputs() != 3);
+    ARM_COMPUTE_ERROR_ON(ctx.num_inputs() != 2 && ctx.num_inputs() != 3);
     ARM_COMPUTE_ERROR_ON(ctx.num_outputs() != 1);
     ARM_COMPUTE_ERROR_ON(dynamic_cast<arm_compute::ICLTensor *>(ctx.input(0)) == nullptr);
     ARM_COMPUTE_ERROR_ON(dynamic_cast<arm_compute::ICLTensor *>(ctx.output(0)) == nullptr);
diff --git a/src/graph/operations/NESimpleOperations.cpp b/src/graph/operations/NESimpleOperations.cpp
index c77aeec..f234341 100644
--- a/src/graph/operations/NESimpleOperations.cpp
+++ b/src/graph/operations/NESimpleOperations.cpp
@@ -138,7 +138,7 @@
 /* DepthwiseConvolutionLayer Layer */
 REGISTER_SIMPLE_OPERATION(NEDepthwiseConvolutionOperation, NEON, OperationType::DepthwiseConvolutionLayer)
 {
-    ARM_COMPUTE_ERROR_ON(ctx.num_inputs() != 2 || ctx.num_inputs() != 3);
+    ARM_COMPUTE_ERROR_ON(ctx.num_inputs() != 2 && ctx.num_inputs() != 3);
     ARM_COMPUTE_ERROR_ON(ctx.num_outputs() != 1);
     ARM_COMPUTE_ERROR_ON(dynamic_cast<arm_compute::ITensor *>(ctx.input(0)) == nullptr);
     ARM_COMPUTE_ERROR_ON(dynamic_cast<arm_compute::ITensor *>(ctx.output(0)) == nullptr);
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolution.cpp b/src/runtime/CL/functions/CLDepthwiseConvolution.cpp
index a701391..8114950 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolution.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolution.cpp
@@ -38,8 +38,7 @@
 void CLDepthwiseConvolution3x3::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, weights);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
 
     _kernel.set_target(CLScheduler::get().target());
     _kernel.configure(input, weights, biases, output, conv_info);
diff --git a/tests/benchmark/CL/SYSTEM/MobileNetV1.cpp b/tests/benchmark/CL/SYSTEM/MobileNetV1.cpp
new file mode 100644
index 0000000..66be323
--- /dev/null
+++ b/tests/benchmark/CL/SYSTEM/MobileNetV1.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/CL/CLTensorAllocator.h"
+#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
+#include "arm_compute/runtime/CL/functions/CLBatchNormalizationLayer.h"
+#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
+#include "arm_compute/runtime/CL/functions/CLDepthwiseConvolution.h"
+#include "arm_compute/runtime/CL/functions/CLDirectConvolutionLayer.h"
+#include "arm_compute/runtime/CL/functions/CLPoolingLayer.h"
+#include "arm_compute/runtime/CL/functions/CLReshapeLayer.h"
+#include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/benchmark/fixtures/MobileNetV1Fixture.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "utils/TypePrinter.h"
+
+namespace arm_compute
+{
+namespace test
+{
+using CLMobileNetV1_224_Fixture = MobileNetV1Fixture<CLTensor, // MobileNetV1 benchmark fixture instantiated with CL types, InputSize = 224
+      CLAccessor,
+      CLActivationLayer,
+      CLBatchNormalizationLayer,
+      CLConvolutionLayer,
+      CLDirectConvolutionLayer,
+      CLDepthwiseConvolution3x3,
+      CLReshapeLayer,
+      CLPoolingLayer,
+      CLSoftmaxLayer,
+      224>;
+
+using CLMobileNetV1_128_Fixture = MobileNetV1Fixture<CLTensor, // Same CL types as above, InputSize = 128
+      CLAccessor,
+      CLActivationLayer,
+      CLBatchNormalizationLayer,
+      CLConvolutionLayer,
+      CLDirectConvolutionLayer,
+      CLDepthwiseConvolution3x3,
+      CLReshapeLayer,
+      CLPoolingLayer,
+      CLSoftmaxLayer,
+      128>;
+
+TEST_SUITE(CL)
+TEST_SUITE(SYSTEM_TEST)
+
+REGISTER_FIXTURE_DATA_TEST_CASE(MobileNetV1_224, CLMobileNetV1_224_Fixture, framework::DatasetMode::ALL,
+                                framework::dataset::make("Batches", { 1, 4, 8 })); // "Batches" dataset values: 1, 4 and 8
+
+REGISTER_FIXTURE_DATA_TEST_CASE(MobileNetV1_128, CLMobileNetV1_128_Fixture, framework::DatasetMode::ALL,
+                                framework::dataset::make("Batches", { 1, 4, 8 })); // Same "Batches" values as the 224 case
+
+TEST_SUITE_END() // SYSTEM_TEST
+TEST_SUITE_END() // CL
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/benchmark/fixtures/MobileNetV1Fixture.h b/tests/benchmark/fixtures/MobileNetV1Fixture.h
new file mode 100644
index 0000000..07333dd
--- /dev/null
+++ b/tests/benchmark/fixtures/MobileNetV1Fixture.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_TEST_MOBILENETV1_FIXTURE
+#define ARM_COMPUTE_TEST_MOBILENETV1_FIXTURE
+
+#include "tests/AssetsLibrary.h"
+#include "tests/Utils.h"
+#include "tests/framework/Fixture.h"
+#include "tests/networks/MobileNetV1Network.h"
+
+namespace arm_compute
+{
+namespace test
+{
+template <typename TensorType,
+          typename Accessor,
+          typename ActivationLayerFunction,
+          typename BatchNormalizationLayerFunction,
+          typename ConvolutionLayerFunction,
+          typename DirectConvolutionLayerFunction,
+          typename DepthwiseConvolutionFunction,
+          typename ReshapeFunction,
+          typename PoolingLayerFunction,
+          typename SoftmaxLayerFunction,
+          unsigned int InputSize> // Passed as the first argument of network.init() in setup()
+class MobileNetV1Fixture : public framework::Fixture // Benchmark fixture that owns a networks::MobileNetV1Network and drives it via setup/run/teardown
+{
+public:
+    template <typename...>
+    void setup(int batches) // Calls init(InputSize, batches), build(), allocate() and fill_random() on the network, in that order
+    {
+        network.init(InputSize, batches);
+        network.build();
+        network.allocate();
+        network.fill_random();
+    }
+
+    void run() // Forwards to network.run()
+    {
+        network.run();
+    }
+
+    void teardown() // Forwards to network.clear()
+    {
+        network.clear();
+    }
+
+private:
+    networks::MobileNetV1Network<TensorType, // Template arguments are forwarded unchanged from the fixture (all except InputSize)
+             Accessor,
+             ActivationLayerFunction,
+             BatchNormalizationLayerFunction,
+             ConvolutionLayerFunction,
+             DirectConvolutionLayerFunction,
+             DepthwiseConvolutionFunction,
+             ReshapeFunction,
+             PoolingLayerFunction,
+             SoftmaxLayerFunction>
+             network{};
+};
+} // namespace test
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_TEST_MOBILENETV1_FIXTURE */
diff --git a/tests/networks/MobileNetV1Network.h b/tests/networks/MobileNetV1Network.h
new file mode 100644
index 0000000..dbe3f49
--- /dev/null
+++ b/tests/networks/MobileNetV1Network.h
@@ -0,0 +1,377 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_MODEL_OBJECTS_MOBILENETV1_H__
+#define __ARM_COMPUTE_TEST_MODEL_OBJECTS_MOBILENETV1_H__
+
+#include "tests/AssetsLibrary.h"
+#include "tests/Globals.h"
+#include "tests/Utils.h"
+
+#include "utils/Utils.h"
+
+#include <memory>
+
+using namespace arm_compute;
+using namespace arm_compute::test;
+
+namespace arm_compute
+{
+namespace test
+{
+namespace networks
+{
+/** MobileNet model object */
+template <typename TensorType,
+          typename Accessor,
+          typename ActivationLayerFunction,
+          typename BatchNormalizationLayerFunction,
+          typename ConvolutionLayerFunction,
+          typename DirectConvolutionLayerFunction,
+          typename DepthwiseConvolutionFunction,
+          typename ReshapeFunction,
+          typename PoolingLayerFunction,
+          typename SoftmaxLayerFunction>
+class MobileNetV1Network
+{
+public:
+    /** Store the batch count and input size, then set shape and data type of the input, output
+     *  and parameter tensors. No allocate() call is made here.
+     *
+     * @param[in] input_spatial_size Extent of the first two input dimensions, asserted to be 128 or 224.
+     * @param[in] batches            Batch count, used as last dimension of the input and output shapes.
+     */
+    void init(unsigned int input_spatial_size, int batches)
+    {
+        _batches            = batches;
+        _input_spatial_size = input_spatial_size;
+
+        // Currently supported sizes
+        ARM_COMPUTE_ERROR_ON(input_spatial_size != 128 && input_spatial_size != 224);
+
+        // All tensors below are initialised as TensorInfo(shape, 1, DataType::F32)
+        const auto describe = [](TensorType &tensor, const TensorShape &shape) { tensor.allocator()->init(TensorInfo(shape, 1, DataType::F32)); };
+        describe(input, TensorShape(input_spatial_size, input_spatial_size, 3U, _batches));
+        describe(output, TensorShape(1001U, _batches));
+        describe(w_conv3x3, TensorShape(3U, 3U, 3U, 32U));
+        describe(mean_conv3x3, TensorShape(32U));
+        describe(var_conv3x3, TensorShape(32U));
+        describe(beta_conv3x3, TensorShape(32U));
+        describe(gamma_conv3x3, TensorShape(32U));
+        // { input, output } feature map counts of the depthwise convolution blocks 0 to 12
+        const unsigned int channels[13][2] = { { 32, 32 }, { 32, 64 }, { 64, 64 }, { 64, 128 }, { 128, 256 }, { 256, 512 }, { 512, 512 },
+            { 512, 512 }, { 512, 512 }, { 512, 512 }, { 512, 512 }, { 512, 1024 }, { 1024, 1024 }
+        };
+        for(unsigned int idx = 0; idx < 13; ++idx)
+        {
+            depthwise_conv_block_init(idx, channels[idx][0], channels[idx][1]);
+        }
+        describe(w_conv1c, TensorShape(1U, 1U, 1024U, 1001U));
+        describe(b_conv1c, TensorShape(1001U));
+        describe(reshape_out, TensorShape(1001U, _batches));
+    }
+
+    /** Build the model.
+     *
+     * Configures the layers in the order run() executes them; blocks 1, 3, 5 and 11 get the (2, 2, ...) PadStrideInfo.
+     */
+    void build()
+    {
+        const auto strided   = [] { return PadStrideInfo(2, 2, 0, 1, 0, 1, DimensionRoundingType::FLOOR); };
+        const auto unit      = [] { return PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR); };
+        const auto pointwise = [] { return PadStrideInfo(1, 1, 0, 0); };
+
+        // Configure Layers
+        conv3x3.configure(&input, &w_conv3x3, nullptr, &conv_out[0], strided());
+        conv3x3_bn.configure(&conv_out[0], nullptr, &mean_conv3x3, &var_conv3x3, &beta_conv3x3, &gamma_conv3x3, 0.001f);
+        conv3x3_act.configure(&conv_out[0], nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
+        depthwise_conv_block_build(0, PadStrideInfo(1, 1, 1, 1), pointwise());
+        for(unsigned int idx = 1; idx < 13; ++idx)
+        {
+            const bool use_strided = (idx == 1) || (idx == 3) || (idx == 5) || (idx == 11);
+            depthwise_conv_block_build(idx, use_strided ? strided() : unit(), pointwise());
+        }
+        pool.configure(&conv_out[13], &pool_out, PoolingLayerInfo(PoolingType::AVG));
+        conv1c.configure(&pool_out, &w_conv1c, &b_conv1c, &conv_out[14], pointwise());
+        reshape.configure(&conv_out[14], &reshape_out);
+        smx.configure(&reshape_out, &output);
+    }
+
+    /** Call allocate() on the allocator of every input, output, parameter and intermediate tensor. */
+    void allocate()
+    {
+        const auto alloc = [](TensorType &tensor) { tensor.allocator()->allocate(); };
+        alloc(input);
+        alloc(output);
+        alloc(w_conv3x3);
+        alloc(mean_conv3x3);
+        alloc(var_conv3x3);
+        alloc(beta_conv3x3);
+        alloc(gamma_conv3x3);
+        // bn_* slot 2 * blk follows the depthwise stage of block blk, slot 2 * blk + 1 its pointwise stage
+        ARM_COMPUTE_ERROR_ON(w_conv.size() != w_dwc.size());
+        for(unsigned int blk = 0; blk < w_conv.size(); ++blk)
+        {
+            alloc(w_dwc[blk]);
+            alloc(bn_mean[2 * blk]);
+            alloc(bn_var[2 * blk]);
+            alloc(bn_beta[2 * blk]);
+            alloc(bn_gamma[2 * blk]);
+            alloc(w_conv[blk]);
+            alloc(bn_mean[2 * blk + 1]);
+            alloc(bn_var[2 * blk + 1]);
+            alloc(bn_beta[2 * blk + 1]);
+            alloc(bn_gamma[2 * blk + 1]);
+        }
+        alloc(w_conv1c);
+        alloc(b_conv1c);
+        // Allocate intermediate buffers
+        for(auto &buffer : conv_out)
+        {
+            alloc(buffer);
+        }
+        for(auto &buffer : dwc_out)
+        {
+            alloc(buffer);
+        }
+        alloc(pool_out);
+        alloc(reshape_out);
+    }
+
+    /** Fills the trainable parameters and input with random data drawn from a uniform real
+     *  distribution constructed with (-1, 1). Each fill call gets the next seed index, starting at 0. */
+    void fill_random()
+    {
+        unsigned int                     next_seed = 0;
+        std::uniform_real_distribution<> uniform(-1, 1);
+        const auto randomize = [&](TensorType &tensor) { library->fill(Accessor(tensor), uniform, next_seed++); };
+        randomize(input);
+        randomize(w_conv3x3);
+        randomize(mean_conv3x3);
+        randomize(var_conv3x3);
+        randomize(beta_conv3x3);
+        randomize(gamma_conv3x3);
+        ARM_COMPUTE_ERROR_ON(w_conv.size() != w_dwc.size());
+        for(unsigned int blk = 0; blk < w_conv.size(); ++blk)
+        {
+            randomize(w_dwc[blk]);
+            randomize(bn_mean[2 * blk]);
+            randomize(bn_var[2 * blk]);
+            randomize(bn_beta[2 * blk]);
+            randomize(bn_gamma[2 * blk]);
+            randomize(w_conv[blk]);
+            randomize(bn_mean[2 * blk + 1]);
+            randomize(bn_var[2 * blk + 1]);
+            randomize(bn_beta[2 * blk + 1]);
+            randomize(bn_gamma[2 * blk + 1]);
+        }
+        randomize(w_conv1c);
+        randomize(b_conv1c);
+    }
+
+    /** Feed the input tensor of the network from a file.
+     *
+     * @param name Name of the file containing the input data; passed to the library's fill_layer_data().
+     */
+    void feed(std::string name)
+    {
+        library->fill_layer_data(Accessor(input), name);
+    }
+
+    /** Get the classification results.
+     *
+     * One label is produced per window position (dimension X collapsed): the index of the
+     * largest value along X, the lowest index winning ties.
+     *
+     * @return Vector containing the classified labels
+     */
+    std::vector<unsigned int> get_classifications()
+    {
+        std::vector<unsigned int> classified_labels;
+        Accessor                  output_accessor(output);
+        Window window;
+        window.set(Window::DimX, Window::Dimension(0, 1, 1));
+        for(unsigned int d = 1; d < output_accessor.shape().num_dimensions(); ++d)
+        {
+            window.set(d, Window::Dimension(0, output_accessor.shape()[d], 1));
+        }
+
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            const float *const scores  = reinterpret_cast<const float *>(output_accessor(id));
+            unsigned int       max_idx = 0;
+            // Start from element 0 instead of a 0.f threshold so rows without positive values are handled
+            for(unsigned int l = 1; l < output_accessor.shape().x(); ++l)
+            {
+                if(scores[l] > scores[max_idx])
+                {
+                    max_idx = l;
+                }
+            }
+            classified_labels.push_back(max_idx);
+        });
+        return classified_labels;
+    }
+
+    /** Clear all allocated memory from the tensor objects: free() is called on the allocator
+     *  of every input, output, parameter and intermediate tensor. */
+    void clear()
+    {
+        const auto release = [](TensorType &tensor) { tensor.allocator()->free(); };
+        release(input);
+        release(output);
+        release(w_conv3x3);
+        release(mean_conv3x3);
+        release(var_conv3x3);
+        release(beta_conv3x3);
+        release(gamma_conv3x3);
+        // bn_* slot 2 * blk follows the depthwise stage of block blk, slot 2 * blk + 1 its pointwise stage
+        ARM_COMPUTE_ERROR_ON(w_conv.size() != w_dwc.size());
+        for(unsigned int blk = 0; blk < w_conv.size(); ++blk)
+        {
+            release(w_dwc[blk]);
+            release(bn_mean[2 * blk]);
+            release(bn_var[2 * blk]);
+            release(bn_beta[2 * blk]);
+            release(bn_gamma[2 * blk]);
+            release(w_conv[blk]);
+            release(bn_mean[2 * blk + 1]);
+            release(bn_var[2 * blk + 1]);
+            release(bn_beta[2 * blk + 1]);
+            release(bn_gamma[2 * blk + 1]);
+        }
+        release(w_conv1c);
+        release(b_conv1c);
+        // Free intermediate buffers
+        for(auto &buffer : conv_out)
+        {
+            release(buffer);
+        }
+        for(auto &buffer : dwc_out)
+        {
+            release(buffer);
+        }
+        release(pool_out);
+        release(reshape_out);
+    }
+
+    /** Runs the model
+     *
+     * The functions are executed in the order build() configures them: the first
+     * 3x3 convolution with its batch normalization and activation, every depthwise
+     * convolution block, then pooling, the final 1x1 convolution, reshape and
+     * softmax.
+     */
+    void run()
+    {
+        // First convolution stage
+        conv3x3.run();
+        conv3x3_bn.run();
+        conv3x3_act.run();
+        // One depthwise convolution block per entry of dwc3x3
+        for(unsigned int block = 0; block < dwc3x3.size(); ++block)
+        {
+            depthwise_conv_block_run(block);
+        }
+        // Pooling, final convolution, reshape and softmax
+        pool.run();
+        conv1c.run();
+        reshape.run();
+        smx.run();
+    }
+
+private:
+    /** Set shape and data type of the parameter tensors of block idx (ifm input, ofm output feature maps). */
+    void depthwise_conv_block_init(unsigned int idx, unsigned int ifm, unsigned int ofm)
+    {
+        const auto init_f32 = [](TensorType &tensor, const TensorShape &shape) { tensor.allocator()->init(TensorInfo(shape, 1, DataType::F32)); };
+        // Depthwise convolution weights and the batch normalization parameters in slot 2 * idx (ifm entries each)
+        init_f32(w_dwc[idx], TensorShape(3U, 3U, ifm));
+        init_f32(bn_mean[2 * idx], TensorShape(ifm));
+        init_f32(bn_var[2 * idx], TensorShape(ifm));
+        init_f32(bn_beta[2 * idx], TensorShape(ifm));
+        init_f32(bn_gamma[2 * idx], TensorShape(ifm));
+        // Convolution weights and the batch normalization parameters in slot 2 * idx + 1 (ofm entries each)
+        init_f32(w_conv[idx], TensorShape(1U, 1U, ifm, ofm));
+        init_f32(bn_mean[2 * idx + 1], TensorShape(ofm));
+        init_f32(bn_var[2 * idx + 1], TensorShape(ofm));
+        init_f32(bn_beta[2 * idx + 1], TensorShape(ofm));
+        init_f32(bn_gamma[2 * idx + 1], TensorShape(ofm));
+    }
+    void depthwise_conv_block_build(unsigned int idx, PadStrideInfo dwc_ps, PadStrideInfo conv_ps)
+    {
+        // Depthwise stage of block idx: conv_out[idx] -> depthwise convolution (dwc_ps) -> dwc_out[idx]; BN and bounded ReLU (6.f) get a nullptr output
+        dwc3x3[idx].configure(&conv_out[idx], &w_dwc[idx], nullptr, &dwc_out[idx], dwc_ps);
+        bn[2 * idx].configure(&dwc_out[idx], nullptr, &bn_mean[2 * idx], &bn_var[2 * idx], &bn_beta[2 * idx], &bn_gamma[2 * idx], 0.001f);
+        act[2 * idx].configure(&dwc_out[idx], nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
+        // Pointwise stage: dwc_out[idx] -> convolution (conv_ps) -> conv_out[idx + 1]; BN and activation must target that same output tensor
+        conv1x1[idx].configure(&dwc_out[idx], &w_conv[idx], nullptr, &conv_out[idx + 1], conv_ps);
+        bn[2 * idx + 1].configure(&conv_out[idx + 1], nullptr, &bn_mean[2 * idx + 1], &bn_var[2 * idx + 1], &bn_beta[2 * idx + 1], &bn_gamma[2 * idx + 1], 0.001f);
+        act[2 * idx + 1].configure(&conv_out[idx + 1], nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));
+    }
+    void depthwise_conv_block_run(unsigned int idx) // Runs the six functions of block idx in the order they are configured
+    {
+        dwc3x3[idx].run();       // depthwise convolution of the block
+        bn[2 * idx].run();       // batch normalization in slot 2 * idx
+        act[2 * idx].run();      // activation in slot 2 * idx
+        conv1x1[idx].run();      // 1x1 convolution of the block
+        bn[2 * idx + 1].run();   // batch normalization in slot 2 * idx + 1
+        act[2 * idx + 1].run();  // activation in slot 2 * idx + 1
+    }
+
+private:
+    unsigned int _batches{ 0 };            // Batch count stored by init()
+    unsigned int _input_spatial_size{ 0 }; // Input spatial size stored by init()
+    // Functions, configured by build() and executed by run()
+    ConvolutionLayerFunction        conv3x3{};     // First convolution: input -> conv_out[0]
+    BatchNormalizationLayerFunction conv3x3_bn{};  // Batch normalization configured on conv_out[0]
+    ActivationLayerFunction         conv3x3_act{}; // Activation configured on conv_out[0]
+    std::array<ActivationLayerFunction, 26>         act{ {} };     // Two per block: slots 2 * i and 2 * i + 1 belong to block i
+    std::array<BatchNormalizationLayerFunction, 26> bn{ {} };      // Two per block: slots 2 * i and 2 * i + 1 belong to block i
+    std::array<DepthwiseConvolutionFunction, 13>    dwc3x3{ {} };  // Depthwise convolution of each block
+    std::array<DirectConvolutionLayerFunction, 13>  conv1x1{ {} }; // 1x1 convolution of each block
+    DirectConvolutionLayerFunction conv1c{};  // Final 1x1 convolution: pool_out -> conv_out[14]
+    PoolingLayerFunction           pool{};    // Pooling: conv_out[13] -> pool_out
+    ReshapeFunction                reshape{}; // Reshape: conv_out[14] -> reshape_out
+    SoftmaxLayerFunction           smx{};     // Softmax: reshape_out -> output
+    // Parameter tensors, described by init() and filled by fill_random()
+    TensorType w_conv3x3{}, mean_conv3x3{}, var_conv3x3{}, beta_conv3x3{}, gamma_conv3x3{};
+    std::array<TensorType, 13> w_conv{ {} };   // Weights of conv1x1[i]
+    std::array<TensorType, 13> w_dwc{ {} };    // Weights of dwc3x3[i]
+    std::array<TensorType, 26> bn_mean{ {} };  // Mean inputs of bn[j]
+    std::array<TensorType, 26> bn_var{ {} };   // Variance inputs of bn[j]
+    std::array<TensorType, 26> bn_beta{ {} };  // Beta inputs of bn[j]
+    std::array<TensorType, 26> bn_gamma{ {} }; // Gamma inputs of bn[j]
+    TensorType w_conv1c{}, b_conv1c{};         // Weights and biases of conv1c
+    // Network input and final output
+    TensorType input{}, output{};
+    // Intermediate buffers
+    std::array<TensorType, 15> conv_out{ {} }; // [0]: first convolution, [i + 1]: conv1x1 of block i, [14]: conv1c
+    std::array<TensorType, 13> dwc_out{ {} };  // [i]: output of dwc3x3[i]
+    TensorType pool_out{};                     // Output of pool
+    TensorType reshape_out{};                  // Output of reshape, input of smx
+};
+} // namespace networks
+} // namespace test
+} // namespace arm_compute
+#endif //__ARM_COMPUTE_TEST_MODEL_OBJECTS_MOBILENETV1_H__