blob: d6b5c57a7ef3fa1828623b1ee610ee4721835a3d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
Nattapat Chaimanowong55b1cda2018-10-10 14:51:27 +01006#include "ClConstantWorkload.hpp"
arovir01616e7752018-10-01 17:08:59 +01007
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <Half.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
10#include <cl/ClTensorHandle.hpp>
11#include <backendsCommon/CpuTensorHandle.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Matthew Bentham14e46692018-09-20 15:35:30 +010013#include "ClWorkloadUtils.hpp"
14
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17
Mike Kelly0886ac42020-04-27 09:55:40 +010018arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output)
19{
20 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
21
Teresa Charlin9ad2e5b2020-04-10 22:34:48 +010022 std::array<arm_compute::DataType,8> supportedTypes = {
Mike Kelly0886ac42020-04-27 09:55:40 +010023 arm_compute::DataType::F16,
24 arm_compute::DataType::F32,
25 arm_compute::DataType::QASYMM8,
26 arm_compute::DataType::QASYMM8_SIGNED,
27 arm_compute::DataType::QSYMM16,
28 arm_compute::DataType::QSYMM8,
Teresa Charlin9ad2e5b2020-04-10 22:34:48 +010029 arm_compute::DataType::QSYMM8_PER_CHANNEL,
30 arm_compute::DataType::S32
Mike Kelly0886ac42020-04-27 09:55:40 +010031 };
32 auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
33
34 if (it != end(supportedTypes))
35 {
36 return arm_compute::Status{};
37 }
38 else
39 {
40 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
41 }
42}
43
Nattapat Chaimanowong55b1cda2018-10-10 14:51:27 +010044ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
45 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
46 , m_RanOnce(false)
telsoa014fcda012018-03-09 14:13:49 +000047{
Nattapat Chaimanowong55b1cda2018-10-10 14:51:27 +010048}
49
50void ClConstantWorkload::Execute() const
51{
52 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
53
telsoa014fcda012018-03-09 14:13:49 +000054 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
55 // on the first inference, then reused for subsequent inferences.
56 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
57 // have been configured at the time.
58 if (!m_RanOnce)
59 {
60 const ConstantQueueDescriptor& data = this->m_Data;
61
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010062 ARMNN_ASSERT(data.m_LayerOutput != nullptr);
telsoa014fcda012018-03-09 14:13:49 +000063 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
telsoa01c577f2c2018-08-31 09:22:23 +010064 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
telsoa014fcda012018-03-09 14:13:49 +000065
telsoa01c577f2c2018-08-31 09:22:23 +010066 switch (computeDataType)
telsoa014fcda012018-03-09 14:13:49 +000067 {
telsoa01c577f2c2018-08-31 09:22:23 +010068 case arm_compute::DataType::F16:
69 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010070 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
telsoa01c577f2c2018-08-31 09:22:23 +010071 break;
72 }
73 case arm_compute::DataType::F32:
telsoa014fcda012018-03-09 14:13:49 +000074 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010075 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
telsoa014fcda012018-03-09 14:13:49 +000076 break;
77 }
telsoa01c577f2c2018-08-31 09:22:23 +010078 case arm_compute::DataType::QASYMM8:
telsoa014fcda012018-03-09 14:13:49 +000079 {
Matthew Benthamca6616c2018-09-21 15:16:53 +010080 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
telsoa014fcda012018-03-09 14:13:49 +000081 break;
82 }
Mike Kelly0886ac42020-04-27 09:55:40 +010083 case arm_compute::DataType::QASYMM8_SIGNED:
84 {
85 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
86 break;
87 }
88 case arm_compute::DataType::QSYMM16:
89 {
90 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
91 break;
92 }
93 case arm_compute::DataType::QSYMM8:
94 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
95 {
96 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
97 break;
98 }
Teresa Charlin9ad2e5b2020-04-10 22:34:48 +010099 case arm_compute::DataType::S32:
100 {
101 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
102 break;
103 }
telsoa014fcda012018-03-09 14:13:49 +0000104 default:
105 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100106 ARMNN_ASSERT_MSG(false, "Unknown data type");
telsoa014fcda012018-03-09 14:13:49 +0000107 break;
108 }
109 }
110
111 m_RanOnce = true;
112 }
113}
114
Matthew Bentham14e46692018-09-20 15:35:30 +0100115} //namespace armnn