//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <aclCommon/ArmComputeTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
10
11#include <Half.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
15#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/CL/CLSubTensor.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
18#include <arm_compute/runtime/MemoryGroup.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
Cathal Corbettd9e55f02023-01-11 13:03:21 +000022#include <aclCommon/IClTensorHandle.hpp>
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +000023
Francis Murtaghe73eda92021-05-21 13:36:54 +010024#include <CL/cl_ext.h>
David Monahane4a41dc2021-04-14 16:55:36 +010025#include <arm_compute/core/CL/CLKernelLibrary.h>
26
27namespace armnn
28{
29
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +000030class ClImportTensorHandle : public IClTensorHandle
David Monahane4a41dc2021-04-14 16:55:36 +010031{
32public:
33 ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
34 : m_ImportFlags(importFlags)
35 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
39 ClImportTensorHandle(const TensorInfo& tensorInfo,
40 DataLayout dataLayout,
41 MemorySourceFlags importFlags)
David Monahan6642b8a2021-11-04 16:31:46 +000042 : m_ImportFlags(importFlags), m_Imported(false)
David Monahane4a41dc2021-04-14 16:55:36 +010043 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
47 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
49 virtual void Allocate() override {}
50 virtual void Manage() override {}
51
52 virtual const void* Map(bool blocking = true) const override
53 {
54 IgnoreUnused(blocking);
55 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
56 }
57
58 virtual void Unmap() const override {}
59
60 virtual ITensorHandle* GetParent() const override { return nullptr; }
61
62 virtual arm_compute::DataType GetDataType() const override
63 {
64 return m_Tensor.info()->data_type();
65 }
66
67 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
68 {
69 IgnoreUnused(memoryGroup);
70 }
71
72 TensorShape GetStrides() const override
73 {
74 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
75 }
76
77 TensorShape GetShape() const override
78 {
79 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
80 }
81
82 void SetImportFlags(MemorySourceFlags importFlags)
83 {
84 m_ImportFlags = importFlags;
85 }
86
87 MemorySourceFlags GetImportFlags() const override
88 {
89 return m_ImportFlags;
90 }
91
92 virtual bool Import(void* memory, MemorySource source) override
93 {
94 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
95 {
96 if (source == MemorySource::Malloc)
97 {
David Monahane4a41dc2021-04-14 16:55:36 +010098 const cl_import_properties_arm importProperties[] =
99 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100100 CL_IMPORT_TYPE_ARM,
101 CL_IMPORT_TYPE_HOST_ARM,
102 0
David Monahane4a41dc2021-04-14 16:55:36 +0100103 };
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100104 return ClImport(importProperties, memory);
105 }
106 if (source == MemorySource::DmaBuf)
107 {
108 const cl_import_properties_arm importProperties[] =
David Monahane4a41dc2021-04-14 16:55:36 +0100109 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100110 CL_IMPORT_TYPE_ARM,
111 CL_IMPORT_TYPE_DMA_BUF_ARM,
Francis Murtaghf5d5e6c2021-07-26 13:19:33 +0100112 CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM,
113 CL_TRUE,
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100114 0
115 };
David Monahane4a41dc2021-04-14 16:55:36 +0100116
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100117 return ClImport(importProperties, memory);
David Monahane4a41dc2021-04-14 16:55:36 +0100118
David Monahane4a41dc2021-04-14 16:55:36 +0100119 }
Francis Murtagh9db96e02021-08-13 16:15:09 +0100120 if (source == MemorySource::DmaBufProtected)
121 {
122 const cl_import_properties_arm importProperties[] =
123 {
124 CL_IMPORT_TYPE_ARM,
125 CL_IMPORT_TYPE_DMA_BUF_ARM,
126 CL_IMPORT_TYPE_PROTECTED_ARM,
127 CL_TRUE,
128 0
129 };
130
131 return ClImport(importProperties, memory, true);
132
133 }
David Monahan6642b8a2021-11-04 16:31:46 +0000134 // Case for importing memory allocated by OpenCl externally directly into the tensor
135 else if (source == MemorySource::Gralloc)
136 {
137 // m_Tensor not yet Allocated
138 if (!m_Imported && !m_Tensor.buffer())
139 {
140 // Importing memory allocated by OpenCl into the tensor directly.
141 arm_compute::Status status =
142 m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
143 m_Imported = bool(status);
144 if (!m_Imported)
145 {
146 throw MemoryImportException(status.error_description());
147 }
148 return m_Imported;
149 }
150
151 // m_Tensor.buffer() initially allocated with Allocate().
152 else if (!m_Imported && m_Tensor.buffer())
153 {
154 throw MemoryImportException(
155 "ClImportTensorHandle::Import Attempting to import on an already allocated tensor");
156 }
157
158 // m_Tensor.buffer() previously imported.
159 else if (m_Imported)
160 {
161 // Importing memory allocated by OpenCl into the tensor directly.
162 arm_compute::Status status =
163 m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
164 m_Imported = bool(status);
165 if (!m_Imported)
166 {
167 throw MemoryImportException(status.error_description());
168 }
169 return m_Imported;
170 }
171 else
172 {
173 throw MemoryImportException("ClImportTensorHandle::Failed to Import Gralloc Memory");
174 }
175 }
David Monahane4a41dc2021-04-14 16:55:36 +0100176 else
177 {
178 throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
179 }
180 }
181 else
182 {
183 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
184 }
David Monahane4a41dc2021-04-14 16:55:36 +0100185 }
186
Sadik Armagana045ac02022-07-01 14:32:05 +0100187 virtual bool CanBeImported(void* /*memory*/, MemorySource source) override
Nikhil Raj60ab9762022-01-13 09:34:44 +0000188 {
189 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
190 {
191 if (source == MemorySource::Malloc)
192 {
Sadik Armagana045ac02022-07-01 14:32:05 +0100193 // Returning true as ClImport() function will decide if memory can be imported or not
194 return true;
Nikhil Raj60ab9762022-01-13 09:34:44 +0000195 }
196 }
197 else
198 {
199 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
200 }
201 return false;
202 }
203
David Monahane4a41dc2021-04-14 16:55:36 +0100204private:
Francis Murtagh9db96e02021-08-13 16:15:09 +0100205 bool ClImport(const cl_import_properties_arm* importProperties, void* memory, bool isProtected = false)
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100206 {
Jan Eilersc1c872f2021-07-22 13:17:04 +0100207 size_t totalBytes = m_Tensor.info()->total_size();
208
Nikhil Raj60ab9762022-01-13 09:34:44 +0000209 // Round the size of the mapping to match the CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE
210 // This does not change the size of the buffer, only the size of the mapping the buffer is mapped to
Jan Eilersc1c872f2021-07-22 13:17:04 +0100211 auto cachelineAlignment =
212 arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
Narumol Prangnawarate2af6f42022-01-28 17:59:18 +0000213 auto roundedSize = totalBytes;
214 if (totalBytes % cachelineAlignment != 0)
215 {
216 roundedSize = cachelineAlignment + totalBytes - (totalBytes % cachelineAlignment);
217 }
Jan Eilersc1c872f2021-07-22 13:17:04 +0100218
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100219 cl_int error = CL_SUCCESS;
Francis Murtagh9db96e02021-08-13 16:15:09 +0100220 cl_mem buffer;
221 if (isProtected)
222 {
223 buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
224 CL_MEM_HOST_NO_ACCESS, importProperties, memory, roundedSize, &error);
225 }
226 else
227 {
228 buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
229 CL_MEM_READ_WRITE, importProperties, memory, roundedSize, &error);
230 }
231
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100232 if (error != CL_SUCCESS)
233 {
Colm Donelan194086f2022-11-14 17:23:07 +0000234 throw MemoryImportException("ClImportTensorHandle::Invalid imported memory: " + std::to_string(error));
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100235 }
236
237 cl::Buffer wrappedBuffer(buffer);
238 arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);
239
240 // Use the overloaded bool operator of Status to check if it is success, if not throw an exception
241 // with the Status error message
242 bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
243 if (!imported)
244 {
245 throw MemoryImportException(status.error_description());
246 }
247
248 ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
249 return imported;
250 }
David Monahane4a41dc2021-04-14 16:55:36 +0100251 // Only used for testing
252 void CopyOutTo(void* memory) const override
253 {
254 const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
255 switch(this->GetDataType())
256 {
257 case arm_compute::DataType::F32:
258 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
259 static_cast<float*>(memory));
260 break;
261 case arm_compute::DataType::U8:
262 case arm_compute::DataType::QASYMM8:
263 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
264 static_cast<uint8_t*>(memory));
265 break;
266 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
267 case arm_compute::DataType::QASYMM8_SIGNED:
268 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
269 static_cast<int8_t*>(memory));
270 break;
271 case arm_compute::DataType::F16:
272 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
273 static_cast<armnn::Half*>(memory));
274 break;
275 case arm_compute::DataType::S16:
276 case arm_compute::DataType::QSYMM16:
277 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
278 static_cast<int16_t*>(memory));
279 break;
280 case arm_compute::DataType::S32:
281 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
282 static_cast<int32_t*>(memory));
283 break;
284 default:
285 {
286 throw armnn::UnimplementedException();
287 }
288 }
289 const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
290 }
291
292 // Only used for testing
293 void CopyInFrom(const void* memory) override
294 {
295 this->Map(true);
296 switch(this->GetDataType())
297 {
298 case arm_compute::DataType::F32:
299 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
300 this->GetTensor());
301 break;
302 case arm_compute::DataType::U8:
303 case arm_compute::DataType::QASYMM8:
304 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
305 this->GetTensor());
306 break;
307 case arm_compute::DataType::F16:
308 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
309 this->GetTensor());
310 break;
311 case arm_compute::DataType::S16:
312 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
313 case arm_compute::DataType::QASYMM8_SIGNED:
314 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
315 this->GetTensor());
316 break;
317 case arm_compute::DataType::QSYMM16:
318 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
319 this->GetTensor());
320 break;
321 case arm_compute::DataType::S32:
322 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
323 this->GetTensor());
324 break;
325 default:
326 {
327 throw armnn::UnimplementedException();
328 }
329 }
330 this->Unmap();
331 }
332
333 arm_compute::CLTensor m_Tensor;
334 MemorySourceFlags m_ImportFlags;
David Monahan6642b8a2021-11-04 16:31:46 +0000335 bool m_Imported;
David Monahane4a41dc2021-04-14 16:55:36 +0100336};
337
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +0000338class ClImportSubTensorHandle : public IClTensorHandle
David Monahane4a41dc2021-04-14 16:55:36 +0100339{
340public:
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +0000341 ClImportSubTensorHandle(IClTensorHandle* parent,
342 const arm_compute::TensorShape& shape,
343 const arm_compute::Coordinates& coords)
David Monahane4a41dc2021-04-14 16:55:36 +0100344 : m_Tensor(&parent->GetTensor(), shape, coords)
345 {
346 parentHandle = parent;
347 }
348
349 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
350 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
351
352 virtual void Allocate() override {}
353 virtual void Manage() override {}
354
355 virtual const void* Map(bool blocking = true) const override
356 {
357 IgnoreUnused(blocking);
358 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
359 }
360 virtual void Unmap() const override {}
361
362 virtual ITensorHandle* GetParent() const override { return parentHandle; }
363
364 virtual arm_compute::DataType GetDataType() const override
365 {
366 return m_Tensor.info()->data_type();
367 }
368
369 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
370 {
371 IgnoreUnused(memoryGroup);
372 }
373
374 TensorShape GetStrides() const override
375 {
376 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
377 }
378
379 TensorShape GetShape() const override
380 {
381 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
382 }
383
384private:
385 // Only used for testing
386 void CopyOutTo(void* memory) const override
387 {
388 const_cast<ClImportSubTensorHandle*>(this)->Map(true);
389 switch(this->GetDataType())
390 {
391 case arm_compute::DataType::F32:
392 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
393 static_cast<float*>(memory));
394 break;
395 case arm_compute::DataType::U8:
396 case arm_compute::DataType::QASYMM8:
397 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
398 static_cast<uint8_t*>(memory));
399 break;
400 case arm_compute::DataType::F16:
401 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
402 static_cast<armnn::Half*>(memory));
403 break;
404 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
405 case arm_compute::DataType::QASYMM8_SIGNED:
406 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
407 static_cast<int8_t*>(memory));
408 break;
409 case arm_compute::DataType::S16:
410 case arm_compute::DataType::QSYMM16:
411 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
412 static_cast<int16_t*>(memory));
413 break;
414 case arm_compute::DataType::S32:
415 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
416 static_cast<int32_t*>(memory));
417 break;
418 default:
419 {
420 throw armnn::UnimplementedException();
421 }
422 }
423 const_cast<ClImportSubTensorHandle*>(this)->Unmap();
424 }
425
426 // Only used for testing
427 void CopyInFrom(const void* memory) override
428 {
429 this->Map(true);
430 switch(this->GetDataType())
431 {
432 case arm_compute::DataType::F32:
433 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
434 this->GetTensor());
435 break;
436 case arm_compute::DataType::U8:
437 case arm_compute::DataType::QASYMM8:
438 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
439 this->GetTensor());
440 break;
441 case arm_compute::DataType::F16:
442 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
443 this->GetTensor());
444 break;
445 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
446 case arm_compute::DataType::QASYMM8_SIGNED:
447 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
448 this->GetTensor());
449 break;
450 case arm_compute::DataType::S16:
451 case arm_compute::DataType::QSYMM16:
452 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
453 this->GetTensor());
454 break;
455 case arm_compute::DataType::S32:
456 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
457 this->GetTensor());
458 break;
459 default:
460 {
461 throw armnn::UnimplementedException();
462 }
463 }
464 this->Unmap();
465 }
466
467 mutable arm_compute::CLSubTensor m_Tensor;
468 ITensorHandle* parentHandle = nullptr;
469};
470
471} // namespace armnn