blob: 8b229a1cdaed5f868ff48152ab1055dda565269b [file] [log] [blame]
Narumol Prangnawarat01961a72019-05-30 16:47:12 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "NeonDequantizeWorkload.hpp"

#include "NeonWorkloadUtils.hpp"

#include <arm_compute/runtime/NEON/functions/NEDequantizationLayer.h>

#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <neon/NeonTensorHandle.hpp>

#include <memory>
#include <utility>
15
16namespace armnn
17{
18
19using namespace armcomputetensorutils;
20
21arm_compute::Status NeonDequantizeWorkloadValidate(const TensorInfo& input,
22 const TensorInfo& output)
23{
24 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
25 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
26
27 return arm_compute::NEDequantizationLayer::validate(&aclInput, &aclOutput);
28}
29
30NeonDequantizeWorkload::NeonDequantizeWorkload(const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info)
31 : BaseWorkload<DequantizeQueueDescriptor>(descriptor, info)
32{
33 m_Data.ValidateInputsOutputs("NeonDequantizeWorkload", 1, 1);
34
Derek Lambertic81855f2019-06-13 17:34:19 +010035 arm_compute::ITensor& input = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
36 arm_compute::ITensor& output = boost::polymorphic_downcast<IAclTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
Narumol Prangnawarat01961a72019-05-30 16:47:12 +010037
Matthew Bentham5e98b012020-01-24 23:11:43 +000038 std::unique_ptr<arm_compute::NEDequantizationLayer> layer(new arm_compute::NEDequantizationLayer());
39 layer->configure(&input, &output);
40 layer->prepare();
41 m_Layer.reset(layer.release());
Narumol Prangnawarat01961a72019-05-30 16:47:12 +010042}
43
44void NeonDequantizeWorkload::Execute() const
45{
46 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonDequantizeWorkload_Execute");
47 m_Layer->run();
48}
49
50} //namespace armnn
51