blob: a68d3e635f0ecf6c3d7592c1836506b78072ad30 [file] [log] [blame]
//
// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "GpuFsaConstantWorkload.hpp"
7#include "GpuFsaWorkloadUtils.hpp"
8
9#include <Half.hpp>
10#include <aclCommon/ArmComputeTensorUtils.hpp>
11#include <gpuFsa/GpuFsaTensorHandle.hpp>
12#include <armnn/backends/TensorHandle.hpp>
13
14namespace armnn
15{
16
17arm_compute::Status GpuFsaConstantWorkloadValidate(const TensorInfo& output)
18{
19 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
20
21 std::array<arm_compute::DataType,8> supportedTypes = {
22 arm_compute::DataType::F16,
23 arm_compute::DataType::F32,
24 arm_compute::DataType::QASYMM8,
25 arm_compute::DataType::QASYMM8_SIGNED,
26 arm_compute::DataType::QSYMM16,
27 arm_compute::DataType::QSYMM8,
28 arm_compute::DataType::QSYMM8_PER_CHANNEL,
29 arm_compute::DataType::S32
30 };
31 auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
32
33 if (it != end(supportedTypes))
34 {
35 return arm_compute::Status{};
36 }
37 else
38 {
39 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
40 }
41}
42
43GpuFsaConstantWorkload::GpuFsaConstantWorkload(const ConstantQueueDescriptor& descriptor,
44 const WorkloadInfo& info,
45 const arm_compute::CLCompileContext&)
46 : GpuFsaBaseWorkload<ConstantQueueDescriptor>(descriptor, info)
47 , m_RanOnce(false)
48{
49}
50
51void GpuFsaConstantWorkload::Execute() const
52{
53 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
54 // on the first inference, then reused for subsequent inferences.
55 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
56 // have been configured at the time.
57 if (!m_RanOnce)
58 {
59 const ConstantQueueDescriptor& data = this->m_Data;
David Monahanbd738082023-12-08 12:50:02 +000060 arm_compute::CLTensor& output = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetTensor();
61 arm_compute::DataType computeDataType = static_cast<GpuFsaTensorHandle*>(data.m_Outputs[0])->GetDataType();
62
63 switch (computeDataType)
64 {
65 case arm_compute::DataType::F16:
66 {
67 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
68 break;
69 }
70 case arm_compute::DataType::F32:
71 {
72 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
73 break;
74 }
75 case arm_compute::DataType::QASYMM8:
76 {
77 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
78 break;
79 }
80 case arm_compute::DataType::QASYMM8_SIGNED:
81 {
82 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
83 break;
84 }
85 case arm_compute::DataType::QSYMM16:
86 {
87 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
88 break;
89 }
90 case arm_compute::DataType::QSYMM8:
91 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
92 {
93 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
94 break;
95 }
96 case arm_compute::DataType::S32:
97 {
98 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
99 break;
100 }
101 default:
102 {
Colm Donelanb4ef1632024-02-01 15:00:43 +0000103 throw InvalidArgumentException("Unknown data type passed to GpuFsaConstantWorkload::Execute()");
David Monahanbd738082023-12-08 12:50:02 +0000104 break;
105 }
106 }
107
108 m_RanOnce = true;
109 }
110}
111
112} //namespace armnn