//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "GpuFsaConstantWorkload.hpp"
#include "GpuFsaWorkloadUtils.hpp"

#include <Half.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <gpuFsa/GpuFsaTensorHandle.hpp>
#include <armnn/backends/TensorHandle.hpp>

#include <algorithm>
#include <array>

namespace armnn
{

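// Validates that the data type of the given output TensorInfo maps to an Arm Compute
// data type supported by this constant workload.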
arm_compute::Status GpuFsaConstantWorkloadValidate(const TensorInfo& output)
{
    const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);

    std::array<arm_compute::DataType, 8> supportedTypes = {
        arm_compute::DataType::F16,
        arm_compute::DataType::F32,
        arm_compute::DataType::QASYMM8,
        arm_compute::DataType::QASYMM8_SIGNED,
        arm_compute::DataType::QSYMM16,
        arm_compute::DataType::QSYMM8,
        arm_compute::DataType::QSYMM8_PER_CHANNEL,
        arm_compute::DataType::S32
    };
    auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());

    if (it != end(supportedTypes))
    {
        return arm_compute::Status{};
    }
    else
    {
        return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
    }
}

GpuFsaConstantWorkload::GpuFsaConstantWorkload(const ConstantQueueDescriptor& descriptor,
                                               const WorkloadInfo& info,
                                               const arm_compute::CLCompileContext&)
    : GpuFsaBaseWorkload<ConstantQueueDescriptor>(descriptor, info)
    , m_RanOnce(false)
{
}

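// Copies the constant layer's data into the backing CL tensor on the first execution;
// subsequent executions reuse the already initialised tensor.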
void GpuFsaConstantWorkload::Execute() const
{
    // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
    // on the first inference, then reused for subsequent inferences.
    // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
    // have been configured at the time.
    if (!m_RanOnce)
    {
        const ConstantQueueDescriptor& data = this->m_Data;

        ARMNN_ASSERT(data.m_LayerOutput != nullptr);
        arm_compute::CLTensor& output = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetTensor();
        arm_compute::DataType computeDataType = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetDataType();

        switch (computeDataType)
        {
            case arm_compute::DataType::F16:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
                break;
            }
            case arm_compute::DataType::F32:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
                break;
            }
            case arm_compute::DataType::QASYMM8:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
                break;
            }
            case arm_compute::DataType::QASYMM8_SIGNED:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
                break;
            }
            case arm_compute::DataType::QSYMM16:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
                break;
            }
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
                break;
            }
            case arm_compute::DataType::S32:
            {
                CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
                break;
            }
            default:
            {
                ARMNN_ASSERT_MSG(false, "Unknown data type");
                break;
            }
        }

        m_RanOnce = true;
    }
}

} // namespace armnn