blob: 1f5094e1438843ca79587625fc0e083864c0a5bc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClConvolution2dBaseWorkload.hpp"
David Beck711fa312018-09-24 10:46:38 +01007#include <backends/ClLayerSupport.hpp>
8#include <backends/ClTensorHandle.hpp>
9#include <backends/aclCommon/ArmComputeUtils.hpp>
10#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010011
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h>
13
surmeh013537c2c2018-05-18 16:31:43 +010014namespace armnn
15{
16using namespace armcomputetensorutils;
17
18arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
19 const TensorInfo& output,
20 const Convolution2dDescriptor& descriptor,
21 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010022 const boost::optional<TensorInfo>& biases)
surmeh013537c2c2018-05-18 16:31:43 +010023{
Francis Murtagh351d13d2018-09-24 15:01:18 +010024 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
25 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
26 const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
arovir01a6824102018-08-28 17:40:45 +010027
surmeh013537c2c2018-05-18 16:31:43 +010028 arm_compute::TensorInfo aclBiasesInfo;
29 arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
30
31 if (descriptor.m_BiasEnabled)
32 {
arovir01a6824102018-08-28 17:40:45 +010033 BOOST_ASSERT(biases.is_initialized());
34
Francis Murtagh351d13d2018-09-24 15:01:18 +010035 aclBiasesInfo = BuildArmComputeTensorInfo(biases.get(), descriptor.m_DataLayout);
surmeh013537c2c2018-05-18 16:31:43 +010036 optionalAclBiasesInfo = &aclBiasesInfo;
37 }
38
39 arm_compute::PadStrideInfo layerInfo = BuildArmComputePadStrideInfo(descriptor);
40
41 return arm_compute::CLConvolutionLayer::validate(&aclInputInfo,
42 &aclWeightsInfo,
43 optionalAclBiasesInfo,
44 &aclOutputInfo,
45 layerInfo);
46}
47
48}