//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>
#include <Half.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <aclCommon/IClTensorHandle.hpp>

#include <cassert> // for assert() used in GpuFsaTensorHandle::Manage
namespace armnn
{

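/// Tensor handle for the GpuFsa (dynamic fusion) backend, wrapping an
/// arm_compute::CLTensor that owns its own OpenCL buffer. Memory import is
/// not supported by this handle (see Import/CanBeImported below).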
class GpuFsaTensorHandle : public IClTensorHandle
{
public:
    GpuFsaTensorHandle(const TensorInfo& tensorInfo)
        : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    GpuFsaTensorHandle(const TensorInfo& tensorInfo,
                       DataLayout dataLayout,
                       MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
        : m_ImportFlags(importFlags),
          m_Imported(false),
          m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override
    {
        // If we have enabled importing, don't allocate the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("GpuFsaTensorHandle::Attempting to allocate memory when importing");
        }
        else
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If we have enabled importing, don't manage the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("GpuFsaTensorHandle::Attempting to manage memory when importing");
        }
        else
        {
            assert(m_MemoryGroup != nullptr);
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

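    // Map/Unmap bracket host access to the CL buffer: Map() makes the buffer
    // visible to the CPU (optionally blocking) and returns a pointer to the
    // first element; Unmap() must be called before the GPU uses the tensor again.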
    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

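    // Note: strides are reported in bytes, as stored in arm_compute's TensorInfo.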
    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

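    // This backend does not support importing externally allocated memory:
    // Import() never succeeds and CanBeImported() always returns false.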
    virtual bool Import(void* /*memory*/, MemorySource source) override
    {
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            throw MemoryImportException("GpuFsaTensorHandle::Incorrect import flag");
        }
        m_Imported = false;
        return false;
    }

    virtual bool CanBeImported(void* /*memory*/, MemorySource /*source*/) override
    {
        // This TensorHandle can never import.
        return false;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<armnn::GpuFsaTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<armnn::GpuFsaTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            // S16 is 16 bits wide, so it must use the int16_t overload together with QSYMM16
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    arm_compute::CLTensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
};
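
// A minimal usage sketch, not code from this backend: it assumes the CL
// runtime has already been initialised elsewhere, e.g. via
// arm_compute::CLScheduler::get().default_init(), and the names are
// illustrative only.
//
//     armnn::TensorInfo info({ 1, 8, 8, 3 }, armnn::DataType::Float32);
//     armnn::GpuFsaTensorHandle handle(info);
//     handle.Allocate();                       // creates the backing CL buffer
//     const void* hostPtr = handle.Map(true);  // map (blocking) for host access
//     // ... read tensor data through hostPtr ...
//     handle.Unmap();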
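/// A view onto a region of a parent tensor handle's CLTensor. The sub-tensor
/// shares the parent's buffer, so Allocate() and Manage() are intentionally
/// no-ops.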
class GpuFsaSubTensorHandle : public IClTensorHandle
{
public:
    GpuFsaSubTensorHandle(IClTensorHandle* parent,
                          const arm_compute::TensorShape& shape,
                          const arm_compute::Coordinates& coords)
        : m_Tensor(&parent->GetTensor(), shape, coords)
    {
        parentHandle = parent;
    }

    arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }

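    // No-ops: the sub-tensor aliases its parent's memory, so it neither
    // allocates nor participates in memory-group management itself.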
    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<GpuFsaSubTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<GpuFsaSubTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    mutable arm_compute::CLSubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};
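
// Hypothetical sketch of carving a sub-view out of an existing handle
// (continuing the example above; the names are illustrative only):
//
//     arm_compute::TensorShape subShape(8, 8);   // extent of the view
//     arm_compute::Coordinates origin(0, 0);     // where the view starts
//     armnn::GpuFsaSubTensorHandle subHandle(&handle, subShape, origin);
//     // subHandle aliases handle's CL buffer; no Allocate() call is needed.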

} // namespace armnn