blob: eb63900380e6363d0b512daaa7322873f9a542b3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "ClDequantizeWorkload.hpp"
#include "ClWorkloadUtils.hpp"

#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>

#include <arm_compute/core/Types.h>

#include <cl/ClLayerSupport.hpp>
#include <cl/ClTensorHandle.hpp>

#include <memory>
17
Jim Flynn983daec2019-05-29 16:20:16 +010018namespace armnn
19{
20using namespace armcomputetensorutils;
21
22arm_compute::Status ClDequantizeWorkloadValidate(const TensorInfo& input, const TensorInfo& output)
23{
24 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
25 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
26
27 return arm_compute::CLDequantizationLayer::validate(&aclInputInfo, &aclOutputInfo);
28}
29
30ClDequantizeWorkload::ClDequantizeWorkload(const DequantizeQueueDescriptor& descriptor,
31 const WorkloadInfo& workloadInfo)
32 : BaseWorkload<DequantizeQueueDescriptor>(descriptor, workloadInfo)
33{
Keith Davisa8565012020-02-14 12:22:40 +000034 m_Data.ValidateInputsOutputs("ClDequantizeWorkload", 1, 1);
35
Jan Eilers3c9e0452020-04-10 13:00:44 +010036 arm_compute::ICLTensor& input = armnn::PolymorphicPointerDowncast<IClTensorHandle>(
Jim Flynn983daec2019-05-29 16:20:16 +010037 m_Data.m_Inputs[0])->GetTensor();
38
Jan Eilers3c9e0452020-04-10 13:00:44 +010039 arm_compute::ICLTensor& output = armnn::PolymorphicPointerDowncast<IClTensorHandle>(
Jim Flynn983daec2019-05-29 16:20:16 +010040 m_Data.m_Outputs[0])->GetTensor();
41
42 m_Layer.reset(new arm_compute::CLDequantizationLayer());
43 m_Layer->configure(&input, &output);
44 m_Layer->prepare();
45}
46
47void ClDequantizeWorkload::Execute() const
48{
49 if (m_Layer)
50 {
51 ARMNN_SCOPED_PROFILING_EVENT_CL("ClDequantizeWorkload_Execute");
52 m_Layer->run();
53 }
54}
55
56} // namespace armnn