blob: 56af3b3baaef45bfe71b0d0823831a96e35ce6dd [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#include "ClBaseConstantWorkload.hpp"
David Beck711fa312018-09-24 10:46:38 +01007#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
8#include <backends/ClTensorHandle.hpp>
9#include <backends/CpuTensorHandle.hpp>
10#include <Half.hpp>
telsoa014fcda012018-03-09 14:13:49 +000011
Matthew Bentham14e46692018-09-20 15:35:30 +010012#include "ClWorkloadUtils.hpp"
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16
telsoa01c577f2c2018-08-31 09:22:23 +010017template class ClBaseConstantWorkload<DataType::Float16, DataType::Float32>;
telsoa014fcda012018-03-09 14:13:49 +000018template class ClBaseConstantWorkload<DataType::QuantisedAsymm8>;
19
telsoa01c577f2c2018-08-31 09:22:23 +010020template<armnn::DataType... dataTypes>
21void ClBaseConstantWorkload<dataTypes...>::Execute() const
telsoa014fcda012018-03-09 14:13:49 +000022{
23 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
24 // on the first inference, then reused for subsequent inferences.
25 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
26 // have been configured at the time.
27 if (!m_RanOnce)
28 {
29 const ConstantQueueDescriptor& data = this->m_Data;
30
31 BOOST_ASSERT(data.m_LayerOutput != nullptr);
32 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
telsoa01c577f2c2018-08-31 09:22:23 +010033 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
telsoa014fcda012018-03-09 14:13:49 +000034
telsoa01c577f2c2018-08-31 09:22:23 +010035 switch (computeDataType)
telsoa014fcda012018-03-09 14:13:49 +000036 {
telsoa01c577f2c2018-08-31 09:22:23 +010037 case arm_compute::DataType::F16:
38 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010039 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
telsoa01c577f2c2018-08-31 09:22:23 +010040 break;
41 }
42 case arm_compute::DataType::F32:
telsoa014fcda012018-03-09 14:13:49 +000043 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010044 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
telsoa014fcda012018-03-09 14:13:49 +000045 break;
46 }
telsoa01c577f2c2018-08-31 09:22:23 +010047 case arm_compute::DataType::QASYMM8:
telsoa014fcda012018-03-09 14:13:49 +000048 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010049 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
telsoa014fcda012018-03-09 14:13:49 +000050 break;
51 }
52 default:
53 {
54 BOOST_ASSERT_MSG(false, "Unknown data type");
55 break;
56 }
57 }
58
59 m_RanOnce = true;
60 }
61}
62
63
Matthew Bentham14e46692018-09-20 15:35:30 +010064} //namespace armnn