blob: ec6c97700b4c6fcfbda42c35f80cb8d742c8dc33 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonDepthwiseConvolutionBaseWorkload.hpp"
7
David Beck711fa312018-09-24 10:46:38 +01008#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01009
10namespace armnn
11{
12
13arm_compute::Status NeonDepthwiseConvolutionWorkloadValidate(const TensorInfo& input,
14 const TensorInfo& output,
15 const DepthwiseConvolution2dDescriptor& descriptor,
16 const TensorInfo& weights,
arovir01a6824102018-08-28 17:40:45 +010017 const boost::optional<TensorInfo>& biases)
telsoa01c577f2c2018-08-31 09:22:23 +010018{
19 const arm_compute::TensorInfo aclInputInfo =
20 armcomputetensorutils::BuildArmComputeTensorInfo(input);
21 const arm_compute::TensorInfo aclOutputInfo =
22 armcomputetensorutils::BuildArmComputeTensorInfo(output);
23 const arm_compute::TensorInfo aclWeightsInfo =
24 armcomputetensorutils::BuildArmComputeTensorInfo(weights);
25
26 arm_compute::TensorInfo aclBiasesInfo;
27 arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
arovir01a6824102018-08-28 17:40:45 +010028
telsoa01c577f2c2018-08-31 09:22:23 +010029 if (descriptor.m_BiasEnabled)
30 {
arovir01a6824102018-08-28 17:40:45 +010031 BOOST_ASSERT(biases.is_initialized());
32
33 aclBiasesInfo = armcomputetensorutils::BuildArmComputeTensorInfo(biases.get());
telsoa01c577f2c2018-08-31 09:22:23 +010034 optionalAclBiasesInfo = &aclBiasesInfo;
35 }
36
37 const arm_compute::PadStrideInfo aclPadStrideInfo =
38 armcomputetensorutils::BuildArmComputePadStrideInfo(descriptor);
39 const unsigned int aclDepthMultiplier = weights.GetShape()[0];
40
41 return arm_compute::NEDepthwiseConvolutionLayer::validate(&aclInputInfo,
42 &aclWeightsInfo,
43 optionalAclBiasesInfo,
44 &aclOutputInfo,
45 aclPadStrideInfo,
46 aclDepthMultiplier);
47}
48
arovir01a6824102018-08-28 17:40:45 +010049}