//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <CL/cl_ext.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>

namespace armnn
{

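// Interface for OpenCL tensor handles whose backing memory is imported from an
// external source (zero copy) rather than allocated by the CL backend.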
class IClImportTensorHandle : public IAclTensorHandle
{
public:
    virtual arm_compute::ICLTensor& GetTensor() = 0;
    virtual arm_compute::ICLTensor const& GetTensor() const = 0;
    virtual arm_compute::DataType GetDataType() const = 0;
    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
};

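// Tensor handle wrapping an arm_compute::CLTensor whose memory is imported at
// runtime through the cl_arm_import_memory extension (host malloc or dma_buf
// sources), so the GPU works directly on the caller's buffer instead of a copy.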
class ClImportTensorHandle : public IClImportTensorHandle
{
public:
    ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
        : m_ImportFlags(importFlags)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    ClImportTensorHandle(const TensorInfo& tensorInfo,
                         DataLayout dataLayout,
                         MemorySourceFlags importFlags)
        : m_ImportFlags(importFlags)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

    arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
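    // Imported memory is owned by the caller, so there is nothing for the
    // backend to allocate or manage here.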
    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        IgnoreUnused(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override {}

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        IgnoreUnused(memoryGroup);
    }

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

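    // Imports the caller's memory into the CL runtime, choosing the
    // cl_import_properties_arm from the MemorySource: CL_IMPORT_TYPE_HOST_ARM for
    // Malloc, CL_IMPORT_TYPE_DMA_BUF_ARM for DmaBuf, and additionally
    // CL_IMPORT_TYPE_PROTECTED_ARM for DmaBufProtected.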
    virtual bool Import(void* memory, MemorySource source) override
    {
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            if (source == MemorySource::Malloc)
            {
                const cl_import_properties_arm importProperties[] =
                {
                    CL_IMPORT_TYPE_ARM,
                    CL_IMPORT_TYPE_HOST_ARM,
                    0
                };

                return ClImport(importProperties, memory);
            }
            if (source == MemorySource::DmaBuf)
            {
                const cl_import_properties_arm importProperties[] =
                {
                    CL_IMPORT_TYPE_ARM,
                    CL_IMPORT_TYPE_DMA_BUF_ARM,
                    CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM,
                    CL_TRUE,
                    0
                };

                return ClImport(importProperties, memory);
            }
            if (source == MemorySource::DmaBufProtected)
            {
                const cl_import_properties_arm importProperties[] =
                {
                    CL_IMPORT_TYPE_ARM,
                    CL_IMPORT_TYPE_DMA_BUF_ARM,
                    CL_IMPORT_TYPE_PROTECTED_ARM,
                    CL_TRUE,
                    0
                };

                return ClImport(importProperties, memory, true);
            }
            throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
        }
        else
        {
            throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
        }
    }

private:
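    // Shared implementation for all import paths: wraps the caller's memory with
    // clImportMemoryARM and hands the resulting cl_mem to the tensor's allocator,
    // so the tensor aliases the imported buffer rather than copying it.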
    bool ClImport(const cl_import_properties_arm* importProperties, void* memory, bool isProtected = false)
    {
        size_t totalBytes = m_Tensor.info()->total_size();

        // Round the size of the buffer to a multiple of the CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE
        auto cachelineAlignment =
                arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
        auto roundedSize = cachelineAlignment + totalBytes - (totalBytes % cachelineAlignment);
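        // For example, with a 64-byte cacheline and totalBytes = 1000 this gives
        // roundedSize = 64 + 1000 - (1000 % 64) = 1024. Note that a size that is
        // already aligned is still rounded up by one full cacheline.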
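        // Protected content must not be mappable by the host, so import it with
        // CL_MEM_HOST_NO_ACCESS; otherwise request CL_MEM_READ_WRITE access.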
        cl_int error = CL_SUCCESS;
        cl_mem buffer;
        if (isProtected)
        {
            buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
                                       CL_MEM_HOST_NO_ACCESS, importProperties, memory, roundedSize, &error);
        }
        else
        {
            buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
                                       CL_MEM_READ_WRITE, importProperties, memory, roundedSize, &error);
        }

        if (error != CL_SUCCESS)
        {
            throw MemoryImportException("ClImportTensorHandle::Invalid imported memory: " + std::to_string(error));
        }

        cl::Buffer wrappedBuffer(buffer);
        arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);

        // Check the Status error code for success; if the import failed, throw an
        // exception carrying the Status error message
        bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
        if (!imported)
        {
            throw MemoryImportException(status.error_description());
        }

        ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
        return imported;
    }

    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    arm_compute::CLTensor m_Tensor;
    MemorySourceFlags m_ImportFlags;
};
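
// A minimal usage sketch (illustrative only, assuming a CL context has already
// been set up by the CL backend and that "buffer" is a hypothetical, suitably
// aligned host allocation of at least info.GetNumBytes() bytes):
//
//     armnn::TensorInfo info({ 1, 16, 16, 3 }, armnn::DataType::Float32);
//     armnn::ClImportTensorHandle handle(
//         info, static_cast<armnn::MemorySourceFlags>(armnn::MemorySource::Malloc));
//     handle.Import(buffer, armnn::MemorySource::Malloc); // zero copy: the tensor
//                                                         // now aliases "buffer"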
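// Tensor handle representing a sub-tensor (a view with its own shape and start
// coordinates) of an imported parent tensor; it owns no memory of its own.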
class ClImportSubTensorHandle : public IClImportTensorHandle
{
public:
    ClImportSubTensorHandle(IClImportTensorHandle* parent,
                            const arm_compute::TensorShape& shape,
                            const arm_compute::Coordinates& coords)
        : m_Tensor(&parent->GetTensor(), shape, coords)
    {
        parentHandle = parent;
    }

    arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        IgnoreUnused(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override {}

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        IgnoreUnused(memoryGroup);
    }

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<ClImportSubTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<ClImportSubTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    mutable arm_compute::CLSubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};
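
// A minimal sub-tensor sketch (illustrative only; "parent" is assumed to be a
// valid ClImportTensorHandle whose memory has already been imported):
//
//     arm_compute::TensorShape subShape(16U, 16U, 1U);
//     arm_compute::Coordinates coords(0, 0, 0);
//     armnn::ClImportSubTensorHandle subHandle(&parent, subShape, coords);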

} // namespace armnn