blob: 08f0390e34d55b48d731b6024ee1709adc5d85f4 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonConstantWorkload.hpp"
7
8#include <arm_compute/core/Types.h>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <Half.hpp>
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <neon/NeonTensorHandle.hpp>
12#include <backendsCommon/CpuTensorHandle.hpp>
13#include <backendsCommon/Workload.hpp>
Nattapat Chaimanowong233b3d62018-10-12 12:02:18 +010014
15#include <boost/cast.hpp>
16
17namespace armnn
18{
19
20NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
21 const WorkloadInfo& info)
22 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
23 , m_RanOnce(false)
24{
25}
26
27void NeonConstantWorkload::Execute() const
28{
29 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConstantWorkload_Execute");
30
31 using namespace armcomputetensorutils;
32
33 // The intermediate tensor held by the corresponding layer output handler can be initialised with the
34 // given data on the first inference, then reused for subsequent inferences.
35 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
36 // may not have been configured at the time.
37 if (!m_RanOnce)
38 {
39 const ConstantQueueDescriptor& data = this->m_Data;
40
41 BOOST_ASSERT(data.m_LayerOutput != nullptr);
42 arm_compute::ITensor& output =
43 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
44 arm_compute::DataType computeDataType =
45 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
46
47 switch (computeDataType)
48 {
49 case arm_compute::DataType::F16:
50 {
51 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
52 break;
53 }
54 case arm_compute::DataType::F32:
55 {
56 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
57 break;
58 }
59 case arm_compute::DataType::QASYMM8:
60 {
61 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
62 break;
63 }
64 default:
65 {
66 BOOST_ASSERT_MSG(false, "Unknown data type");
67 break;
68 }
69 }
70
71 m_RanOnce = true;
72 }
73}
74
75} //namespace armnn