// blob: 39ae14eaf3f472b7652b15b20918333ca72f1f01
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClConstantWorkload.hpp"

#include <cstdint>

#include <armnn/Exceptions.hpp>
#include <Half.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <cl/ClTensorHandle.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>

#include "ClWorkloadUtils.hpp"

telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17
Nattapat Chaimanowong55b1cda2018-10-10 14:51:27 +010018ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
19 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
20 , m_RanOnce(false)
telsoa014fcda012018-03-09 14:13:49 +000021{
Nattapat Chaimanowong55b1cda2018-10-10 14:51:27 +010022}
23
24void ClConstantWorkload::Execute() const
25{
26 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
27
telsoa014fcda012018-03-09 14:13:49 +000028 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
29 // on the first inference, then reused for subsequent inferences.
30 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
31 // have been configured at the time.
32 if (!m_RanOnce)
33 {
34 const ConstantQueueDescriptor& data = this->m_Data;
35
36 BOOST_ASSERT(data.m_LayerOutput != nullptr);
37 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
telsoa01c577f2c2018-08-31 09:22:23 +010038 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
telsoa014fcda012018-03-09 14:13:49 +000039
telsoa01c577f2c2018-08-31 09:22:23 +010040 switch (computeDataType)
telsoa014fcda012018-03-09 14:13:49 +000041 {
telsoa01c577f2c2018-08-31 09:22:23 +010042 case arm_compute::DataType::F16:
43 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010044 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
telsoa01c577f2c2018-08-31 09:22:23 +010045 break;
46 }
47 case arm_compute::DataType::F32:
telsoa014fcda012018-03-09 14:13:49 +000048 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010049 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
telsoa014fcda012018-03-09 14:13:49 +000050 break;
51 }
telsoa01c577f2c2018-08-31 09:22:23 +010052 case arm_compute::DataType::QASYMM8:
telsoa014fcda012018-03-09 14:13:49 +000053 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010054 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
telsoa014fcda012018-03-09 14:13:49 +000055 break;
56 }
57 default:
58 {
59 BOOST_ASSERT_MSG(false, "Unknown data type");
60 break;
61 }
62 }
63
64 m_RanOnce = true;
65 }
66}
67
Matthew Bentham14e46692018-09-20 15:35:30 +010068} //namespace armnn