blob: 14fbdf3dd921a4a834393f4ee06f869ae593c524 [file] [log] [blame]
Sadik Armaganfabc2892019-05-31 09:05:11 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "NeonQuantizeWorkload.hpp"
#include "NeonWorkloadUtils.hpp"

#include <neon/NeonTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <arm_compute/core/Types.h>

#include <memory>
13
Sadik Armaganfabc2892019-05-31 09:05:11 +010014namespace armnn
15{
16using namespace armcomputetensorutils;
17
18arm_compute::Status NeonQuantizeWorkloadValidate(const TensorInfo& input, const TensorInfo& output)
19{
20 const arm_compute::TensorInfo neonInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
21 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
22
23 return arm_compute::NEQuantizationLayer::validate(&neonInputInfo, &neonOutputInfo);
24}
25
26NeonQuantizeWorkload::NeonQuantizeWorkload(const QuantizeQueueDescriptor& descriptor,
27 const WorkloadInfo& workloadInfo)
28 : BaseWorkload<QuantizeQueueDescriptor>(descriptor, workloadInfo)
29{
Keith Davisa8565012020-02-14 12:22:40 +000030 m_Data.ValidateInputsOutputs("NeonQuantizeWorkload", 1, 1);
31
Jan Eilers3c9e0452020-04-10 13:00:44 +010032 arm_compute::ITensor& input = PolymorphicPointerDowncast<IAclTensorHandle>(
Sadik Armaganfabc2892019-05-31 09:05:11 +010033 m_Data.m_Inputs[0])->GetTensor();
Jan Eilers3c9e0452020-04-10 13:00:44 +010034 arm_compute::ITensor& output = PolymorphicPointerDowncast<IAclTensorHandle>(
Sadik Armaganfabc2892019-05-31 09:05:11 +010035 m_Data.m_Outputs[0])->GetTensor();
36
37 m_Layer.reset(new arm_compute::NEQuantizationLayer());
38 m_Layer->configure(&input, &output);
39 m_Layer->prepare();
40}
41
42void NeonQuantizeWorkload::Execute() const
43{
44 if (m_Layer)
45 {
46 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonQuantizeWorkload_Execute");
47 m_Layer->run();
48 }
49}
50
51} // namespace armnn