blob: aa535adec987ffc1f1c7c53a25e87df877a2f290 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonDepthwiseConvolutionBaseWorkload.hpp"
7
David Beck711fa312018-09-24 10:46:38 +01008#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01009
10namespace armnn
11{
12
13arm_compute::Status NeonDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
14 const TensorInfo& output,
15 const DepthwiseConvolution2dDescriptor& descriptor,
16 const TensorInfo& weights,
David Beck5eec11d2018-10-04 15:43:17 +010017 const Optional<TensorInfo>& biases)
telsoa01c577f2c2018-08-31 09:22:23 +010018{
19 const arm_compute::TensorInfo aclInputInfo =
Nikhil Raja05c2102018-09-25 16:16:13 +010020 armcomputetensorutils::BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010021 const arm_compute::TensorInfo aclOutputInfo =
Nikhil Raja05c2102018-09-25 16:16:13 +010022 armcomputetensorutils::BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010023 const arm_compute::TensorInfo aclWeightsInfo =
Nikhil Raja05c2102018-09-25 16:16:13 +010024 armcomputetensorutils::BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010025
26 arm_compute::TensorInfo aclBiasesInfo;
27 arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
arovir01a6824102018-08-28 17:40:45 +010028
telsoa01c577f2c2018-08-31 09:22:23 +010029 if (descriptor.m_BiasEnabled)
30 {
David Beck5eec11d2018-10-04 15:43:17 +010031 BOOST_ASSERT(biases.has_value());
arovir01a6824102018-08-28 17:40:45 +010032
David Beck5eec11d2018-10-04 15:43:17 +010033 aclBiasesInfo = armcomputetensorutils::BuildArmComputeTensorInfo(biases.value(), descriptor.m_DataLayout);
telsoa01c577f2c2018-08-31 09:22:23 +010034 optionalAclBiasesInfo = &aclBiasesInfo;
35 }
36
37 const arm_compute::PadStrideInfo aclPadStrideInfo =
38 armcomputetensorutils::BuildArmComputePadStrideInfo(descriptor);
39 const unsigned int aclDepthMultiplier = weights.GetShape()[0];
40
41 return arm_compute::NEDepthwiseConvolutionLayer::validate(&aclInputInfo,
42 &aclWeightsInfo,
43 optionalAclBiasesInfo,
44 &aclOutputInfo,
45 aclPadStrideInfo,
46 aclDepthMultiplier);
47}
48
arovir01a6824102018-08-28 17:40:45 +010049}