blob: a24ab5656ec5477bd4edf0034b0d11dbf5c5e061 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
8#include <aclCommon/ArmComputeTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
10
11#include <Half.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
15#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/CL/CLSubTensor.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
18#include <arm_compute/runtime/MemoryGroup.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
Francis Murtaghe73eda92021-05-21 13:36:54 +010022#include <CL/cl_ext.h>
David Monahane4a41dc2021-04-14 16:55:36 +010023#include <arm_compute/core/CL/CLKernelLibrary.h>
24
25namespace armnn
26{
27
/// Interface for CL tensor handles whose backing memory is imported from an
/// external source (host malloc, dma-buf, gralloc, ...) rather than allocated
/// by Arm NN itself.
class IClImportTensorHandle : public IAclTensorHandle
{
public:
    /// Access the underlying Arm Compute CL tensor (mutable and const forms).
    virtual arm_compute::ICLTensor& GetTensor() = 0;
    virtual arm_compute::ICLTensor const& GetTensor() const = 0;
    /// Element data type as reported by the underlying tensor info.
    virtual arm_compute::DataType GetDataType() const = 0;
    /// Associate the tensor with an ACL memory group.
    /// NOTE(review): both implementations in this file ignore the group, since
    /// imported memory is not pooled.
    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
};
36
37class ClImportTensorHandle : public IClImportTensorHandle
38{
39public:
40 ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
41 : m_ImportFlags(importFlags)
42 {
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
44 }
45
46 ClImportTensorHandle(const TensorInfo& tensorInfo,
47 DataLayout dataLayout,
48 MemorySourceFlags importFlags)
David Monahan6642b8a2021-11-04 16:31:46 +000049 : m_ImportFlags(importFlags), m_Imported(false)
David Monahane4a41dc2021-04-14 16:55:36 +010050 {
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
52 }
53
54 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
55 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
56 virtual void Allocate() override {}
57 virtual void Manage() override {}
58
59 virtual const void* Map(bool blocking = true) const override
60 {
61 IgnoreUnused(blocking);
62 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
63 }
64
65 virtual void Unmap() const override {}
66
67 virtual ITensorHandle* GetParent() const override { return nullptr; }
68
69 virtual arm_compute::DataType GetDataType() const override
70 {
71 return m_Tensor.info()->data_type();
72 }
73
74 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 {
76 IgnoreUnused(memoryGroup);
77 }
78
79 TensorShape GetStrides() const override
80 {
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82 }
83
84 TensorShape GetShape() const override
85 {
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87 }
88
89 void SetImportFlags(MemorySourceFlags importFlags)
90 {
91 m_ImportFlags = importFlags;
92 }
93
94 MemorySourceFlags GetImportFlags() const override
95 {
96 return m_ImportFlags;
97 }
98
99 virtual bool Import(void* memory, MemorySource source) override
100 {
101 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
102 {
103 if (source == MemorySource::Malloc)
104 {
David Monahane4a41dc2021-04-14 16:55:36 +0100105 const cl_import_properties_arm importProperties[] =
106 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100107 CL_IMPORT_TYPE_ARM,
108 CL_IMPORT_TYPE_HOST_ARM,
109 0
David Monahane4a41dc2021-04-14 16:55:36 +0100110 };
111
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100112 return ClImport(importProperties, memory);
113 }
114 if (source == MemorySource::DmaBuf)
115 {
116 const cl_import_properties_arm importProperties[] =
David Monahane4a41dc2021-04-14 16:55:36 +0100117 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100118 CL_IMPORT_TYPE_ARM,
119 CL_IMPORT_TYPE_DMA_BUF_ARM,
Francis Murtaghf5d5e6c2021-07-26 13:19:33 +0100120 CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM,
121 CL_TRUE,
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100122 0
123 };
David Monahane4a41dc2021-04-14 16:55:36 +0100124
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100125 return ClImport(importProperties, memory);
David Monahane4a41dc2021-04-14 16:55:36 +0100126
David Monahane4a41dc2021-04-14 16:55:36 +0100127 }
Francis Murtagh9db96e02021-08-13 16:15:09 +0100128 if (source == MemorySource::DmaBufProtected)
129 {
130 const cl_import_properties_arm importProperties[] =
131 {
132 CL_IMPORT_TYPE_ARM,
133 CL_IMPORT_TYPE_DMA_BUF_ARM,
134 CL_IMPORT_TYPE_PROTECTED_ARM,
135 CL_TRUE,
136 0
137 };
138
139 return ClImport(importProperties, memory, true);
140
141 }
David Monahan6642b8a2021-11-04 16:31:46 +0000142 // Case for importing memory allocated by OpenCl externally directly into the tensor
143 else if (source == MemorySource::Gralloc)
144 {
145 // m_Tensor not yet Allocated
146 if (!m_Imported && !m_Tensor.buffer())
147 {
148 // Importing memory allocated by OpenCl into the tensor directly.
149 arm_compute::Status status =
150 m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
151 m_Imported = bool(status);
152 if (!m_Imported)
153 {
154 throw MemoryImportException(status.error_description());
155 }
156 return m_Imported;
157 }
158
159 // m_Tensor.buffer() initially allocated with Allocate().
160 else if (!m_Imported && m_Tensor.buffer())
161 {
162 throw MemoryImportException(
163 "ClImportTensorHandle::Import Attempting to import on an already allocated tensor");
164 }
165
166 // m_Tensor.buffer() previously imported.
167 else if (m_Imported)
168 {
169 // Importing memory allocated by OpenCl into the tensor directly.
170 arm_compute::Status status =
171 m_Tensor.allocator()->import_memory(cl::Buffer(static_cast<cl_mem>(memory)));
172 m_Imported = bool(status);
173 if (!m_Imported)
174 {
175 throw MemoryImportException(status.error_description());
176 }
177 return m_Imported;
178 }
179 else
180 {
181 throw MemoryImportException("ClImportTensorHandle::Failed to Import Gralloc Memory");
182 }
183 }
David Monahane4a41dc2021-04-14 16:55:36 +0100184 else
185 {
186 throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
187 }
188 }
189 else
190 {
191 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
192 }
David Monahane4a41dc2021-04-14 16:55:36 +0100193 }
194
195private:
Francis Murtagh9db96e02021-08-13 16:15:09 +0100196 bool ClImport(const cl_import_properties_arm* importProperties, void* memory, bool isProtected = false)
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100197 {
Jan Eilersc1c872f2021-07-22 13:17:04 +0100198 size_t totalBytes = m_Tensor.info()->total_size();
199
200 // Round the size of the buffer to a multiple of the CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE
201 auto cachelineAlignment =
202 arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
203 auto roundedSize = cachelineAlignment + totalBytes - (totalBytes % cachelineAlignment);
204
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100205 cl_int error = CL_SUCCESS;
Francis Murtagh9db96e02021-08-13 16:15:09 +0100206 cl_mem buffer;
207 if (isProtected)
208 {
209 buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
210 CL_MEM_HOST_NO_ACCESS, importProperties, memory, roundedSize, &error);
211 }
212 else
213 {
214 buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
215 CL_MEM_READ_WRITE, importProperties, memory, roundedSize, &error);
216 }
217
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100218 if (error != CL_SUCCESS)
219 {
Francis Murtaghf5d5e6c2021-07-26 13:19:33 +0100220 throw MemoryImportException("ClImportTensorHandle::Invalid imported memory" + std::to_string(error));
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100221 }
222
223 cl::Buffer wrappedBuffer(buffer);
224 arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);
225
226 // Use the overloaded bool operator of Status to check if it is success, if not throw an exception
227 // with the Status error message
228 bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
229 if (!imported)
230 {
231 throw MemoryImportException(status.error_description());
232 }
233
234 ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
235 return imported;
236 }
David Monahane4a41dc2021-04-14 16:55:36 +0100237 // Only used for testing
238 void CopyOutTo(void* memory) const override
239 {
240 const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
241 switch(this->GetDataType())
242 {
243 case arm_compute::DataType::F32:
244 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
245 static_cast<float*>(memory));
246 break;
247 case arm_compute::DataType::U8:
248 case arm_compute::DataType::QASYMM8:
249 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
250 static_cast<uint8_t*>(memory));
251 break;
252 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
253 case arm_compute::DataType::QASYMM8_SIGNED:
254 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
255 static_cast<int8_t*>(memory));
256 break;
257 case arm_compute::DataType::F16:
258 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
259 static_cast<armnn::Half*>(memory));
260 break;
261 case arm_compute::DataType::S16:
262 case arm_compute::DataType::QSYMM16:
263 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
264 static_cast<int16_t*>(memory));
265 break;
266 case arm_compute::DataType::S32:
267 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
268 static_cast<int32_t*>(memory));
269 break;
270 default:
271 {
272 throw armnn::UnimplementedException();
273 }
274 }
275 const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
276 }
277
278 // Only used for testing
279 void CopyInFrom(const void* memory) override
280 {
281 this->Map(true);
282 switch(this->GetDataType())
283 {
284 case arm_compute::DataType::F32:
285 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
286 this->GetTensor());
287 break;
288 case arm_compute::DataType::U8:
289 case arm_compute::DataType::QASYMM8:
290 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
291 this->GetTensor());
292 break;
293 case arm_compute::DataType::F16:
294 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
295 this->GetTensor());
296 break;
297 case arm_compute::DataType::S16:
298 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
299 case arm_compute::DataType::QASYMM8_SIGNED:
300 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
301 this->GetTensor());
302 break;
303 case arm_compute::DataType::QSYMM16:
304 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
305 this->GetTensor());
306 break;
307 case arm_compute::DataType::S32:
308 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
309 this->GetTensor());
310 break;
311 default:
312 {
313 throw armnn::UnimplementedException();
314 }
315 }
316 this->Unmap();
317 }
318
319 arm_compute::CLTensor m_Tensor;
320 MemorySourceFlags m_ImportFlags;
David Monahan6642b8a2021-11-04 16:31:46 +0000321 bool m_Imported;
David Monahane4a41dc2021-04-14 16:55:36 +0100322};
323
324class ClImportSubTensorHandle : public IClImportTensorHandle
325{
326public:
327 ClImportSubTensorHandle(IClImportTensorHandle* parent,
328 const arm_compute::TensorShape& shape,
329 const arm_compute::Coordinates& coords)
330 : m_Tensor(&parent->GetTensor(), shape, coords)
331 {
332 parentHandle = parent;
333 }
334
335 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
336 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
337
338 virtual void Allocate() override {}
339 virtual void Manage() override {}
340
341 virtual const void* Map(bool blocking = true) const override
342 {
343 IgnoreUnused(blocking);
344 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
345 }
346 virtual void Unmap() const override {}
347
348 virtual ITensorHandle* GetParent() const override { return parentHandle; }
349
350 virtual arm_compute::DataType GetDataType() const override
351 {
352 return m_Tensor.info()->data_type();
353 }
354
355 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
356 {
357 IgnoreUnused(memoryGroup);
358 }
359
360 TensorShape GetStrides() const override
361 {
362 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
363 }
364
365 TensorShape GetShape() const override
366 {
367 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
368 }
369
370private:
371 // Only used for testing
372 void CopyOutTo(void* memory) const override
373 {
374 const_cast<ClImportSubTensorHandle*>(this)->Map(true);
375 switch(this->GetDataType())
376 {
377 case arm_compute::DataType::F32:
378 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
379 static_cast<float*>(memory));
380 break;
381 case arm_compute::DataType::U8:
382 case arm_compute::DataType::QASYMM8:
383 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
384 static_cast<uint8_t*>(memory));
385 break;
386 case arm_compute::DataType::F16:
387 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
388 static_cast<armnn::Half*>(memory));
389 break;
390 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
391 case arm_compute::DataType::QASYMM8_SIGNED:
392 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
393 static_cast<int8_t*>(memory));
394 break;
395 case arm_compute::DataType::S16:
396 case arm_compute::DataType::QSYMM16:
397 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
398 static_cast<int16_t*>(memory));
399 break;
400 case arm_compute::DataType::S32:
401 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
402 static_cast<int32_t*>(memory));
403 break;
404 default:
405 {
406 throw armnn::UnimplementedException();
407 }
408 }
409 const_cast<ClImportSubTensorHandle*>(this)->Unmap();
410 }
411
412 // Only used for testing
413 void CopyInFrom(const void* memory) override
414 {
415 this->Map(true);
416 switch(this->GetDataType())
417 {
418 case arm_compute::DataType::F32:
419 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
420 this->GetTensor());
421 break;
422 case arm_compute::DataType::U8:
423 case arm_compute::DataType::QASYMM8:
424 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
425 this->GetTensor());
426 break;
427 case arm_compute::DataType::F16:
428 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
429 this->GetTensor());
430 break;
431 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
432 case arm_compute::DataType::QASYMM8_SIGNED:
433 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
434 this->GetTensor());
435 break;
436 case arm_compute::DataType::S16:
437 case arm_compute::DataType::QSYMM16:
438 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
439 this->GetTensor());
440 break;
441 case arm_compute::DataType::S32:
442 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
443 this->GetTensor());
444 break;
445 default:
446 {
447 throw armnn::UnimplementedException();
448 }
449 }
450 this->Unmap();
451 }
452
453 mutable arm_compute::CLSubTensor m_Tensor;
454 ITensorHandle* parentHandle = nullptr;
455};
456
457} // namespace armnn