//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/core/Coordinates.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>

#include <CL/cl_ext.h>

#include <memory>
#include <string>
24
25namespace armnn
26{
27
28class IClImportTensorHandle : public IAclTensorHandle
29{
30public:
31 virtual arm_compute::ICLTensor& GetTensor() = 0;
32 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
33 virtual arm_compute::DataType GetDataType() const = 0;
34 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
35};
36
37class ClImportTensorHandle : public IClImportTensorHandle
38{
39public:
40 ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
41 : m_ImportFlags(importFlags)
42 {
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
44 }
45
46 ClImportTensorHandle(const TensorInfo& tensorInfo,
47 DataLayout dataLayout,
48 MemorySourceFlags importFlags)
49 : m_ImportFlags(importFlags)
50 {
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
52 }
53
54 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
55 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
56 virtual void Allocate() override {}
57 virtual void Manage() override {}
58
59 virtual const void* Map(bool blocking = true) const override
60 {
61 IgnoreUnused(blocking);
62 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
63 }
64
65 virtual void Unmap() const override {}
66
67 virtual ITensorHandle* GetParent() const override { return nullptr; }
68
69 virtual arm_compute::DataType GetDataType() const override
70 {
71 return m_Tensor.info()->data_type();
72 }
73
74 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 {
76 IgnoreUnused(memoryGroup);
77 }
78
79 TensorShape GetStrides() const override
80 {
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82 }
83
84 TensorShape GetShape() const override
85 {
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87 }
88
89 void SetImportFlags(MemorySourceFlags importFlags)
90 {
91 m_ImportFlags = importFlags;
92 }
93
94 MemorySourceFlags GetImportFlags() const override
95 {
96 return m_ImportFlags;
97 }
98
99 virtual bool Import(void* memory, MemorySource source) override
100 {
101 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
102 {
103 if (source == MemorySource::Malloc)
104 {
David Monahane4a41dc2021-04-14 16:55:36 +0100105 const cl_import_properties_arm importProperties[] =
106 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100107 CL_IMPORT_TYPE_ARM,
108 CL_IMPORT_TYPE_HOST_ARM,
109 0
David Monahane4a41dc2021-04-14 16:55:36 +0100110 };
111
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100112 return ClImport(importProperties, memory);
113 }
114 if (source == MemorySource::DmaBuf)
115 {
116 const cl_import_properties_arm importProperties[] =
David Monahane4a41dc2021-04-14 16:55:36 +0100117 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100118 CL_IMPORT_TYPE_ARM,
119 CL_IMPORT_TYPE_DMA_BUF_ARM,
Francis Murtaghf5d5e6c2021-07-26 13:19:33 +0100120 CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM,
121 CL_TRUE,
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100122 0
123 };
David Monahane4a41dc2021-04-14 16:55:36 +0100124
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100125 return ClImport(importProperties, memory);
David Monahane4a41dc2021-04-14 16:55:36 +0100126
David Monahane4a41dc2021-04-14 16:55:36 +0100127 }
128 else
129 {
130 throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
131 }
132 }
133 else
134 {
135 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
136 }
137 return false;
138 }
139
140private:
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100141 bool ClImport(const cl_import_properties_arm* importProperties, void* memory)
142 {
143 const size_t totalBytes = m_Tensor.info()->total_size();
144 cl_int error = CL_SUCCESS;
145 cl_mem buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
146 CL_MEM_READ_WRITE, importProperties, memory, totalBytes, &error);
147 if (error != CL_SUCCESS)
148 {
Francis Murtaghf5d5e6c2021-07-26 13:19:33 +0100149 throw MemoryImportException("ClImportTensorHandle::Invalid imported memory" + std::to_string(error));
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100150 }
151
152 cl::Buffer wrappedBuffer(buffer);
153 arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);
154
155 // Use the overloaded bool operator of Status to check if it is success, if not throw an exception
156 // with the Status error message
157 bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
158 if (!imported)
159 {
160 throw MemoryImportException(status.error_description());
161 }
162
163 ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
164 return imported;
165 }
David Monahane4a41dc2021-04-14 16:55:36 +0100166 // Only used for testing
167 void CopyOutTo(void* memory) const override
168 {
169 const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
170 switch(this->GetDataType())
171 {
172 case arm_compute::DataType::F32:
173 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
174 static_cast<float*>(memory));
175 break;
176 case arm_compute::DataType::U8:
177 case arm_compute::DataType::QASYMM8:
178 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
179 static_cast<uint8_t*>(memory));
180 break;
181 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
182 case arm_compute::DataType::QASYMM8_SIGNED:
183 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
184 static_cast<int8_t*>(memory));
185 break;
186 case arm_compute::DataType::F16:
187 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
188 static_cast<armnn::Half*>(memory));
189 break;
190 case arm_compute::DataType::S16:
191 case arm_compute::DataType::QSYMM16:
192 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
193 static_cast<int16_t*>(memory));
194 break;
195 case arm_compute::DataType::S32:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<int32_t*>(memory));
198 break;
199 default:
200 {
201 throw armnn::UnimplementedException();
202 }
203 }
204 const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
205 }
206
207 // Only used for testing
208 void CopyInFrom(const void* memory) override
209 {
210 this->Map(true);
211 switch(this->GetDataType())
212 {
213 case arm_compute::DataType::F32:
214 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
215 this->GetTensor());
216 break;
217 case arm_compute::DataType::U8:
218 case arm_compute::DataType::QASYMM8:
219 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
220 this->GetTensor());
221 break;
222 case arm_compute::DataType::F16:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
224 this->GetTensor());
225 break;
226 case arm_compute::DataType::S16:
227 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
228 case arm_compute::DataType::QASYMM8_SIGNED:
229 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
230 this->GetTensor());
231 break;
232 case arm_compute::DataType::QSYMM16:
233 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
234 this->GetTensor());
235 break;
236 case arm_compute::DataType::S32:
237 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
238 this->GetTensor());
239 break;
240 default:
241 {
242 throw armnn::UnimplementedException();
243 }
244 }
245 this->Unmap();
246 }
247
248 arm_compute::CLTensor m_Tensor;
249 MemorySourceFlags m_ImportFlags;
250};
251
252class ClImportSubTensorHandle : public IClImportTensorHandle
253{
254public:
255 ClImportSubTensorHandle(IClImportTensorHandle* parent,
256 const arm_compute::TensorShape& shape,
257 const arm_compute::Coordinates& coords)
258 : m_Tensor(&parent->GetTensor(), shape, coords)
259 {
260 parentHandle = parent;
261 }
262
263 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
264 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
265
266 virtual void Allocate() override {}
267 virtual void Manage() override {}
268
269 virtual const void* Map(bool blocking = true) const override
270 {
271 IgnoreUnused(blocking);
272 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
273 }
274 virtual void Unmap() const override {}
275
276 virtual ITensorHandle* GetParent() const override { return parentHandle; }
277
278 virtual arm_compute::DataType GetDataType() const override
279 {
280 return m_Tensor.info()->data_type();
281 }
282
283 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
284 {
285 IgnoreUnused(memoryGroup);
286 }
287
288 TensorShape GetStrides() const override
289 {
290 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
291 }
292
293 TensorShape GetShape() const override
294 {
295 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
296 }
297
298private:
299 // Only used for testing
300 void CopyOutTo(void* memory) const override
301 {
302 const_cast<ClImportSubTensorHandle*>(this)->Map(true);
303 switch(this->GetDataType())
304 {
305 case arm_compute::DataType::F32:
306 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
307 static_cast<float*>(memory));
308 break;
309 case arm_compute::DataType::U8:
310 case arm_compute::DataType::QASYMM8:
311 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
312 static_cast<uint8_t*>(memory));
313 break;
314 case arm_compute::DataType::F16:
315 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
316 static_cast<armnn::Half*>(memory));
317 break;
318 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
319 case arm_compute::DataType::QASYMM8_SIGNED:
320 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
321 static_cast<int8_t*>(memory));
322 break;
323 case arm_compute::DataType::S16:
324 case arm_compute::DataType::QSYMM16:
325 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
326 static_cast<int16_t*>(memory));
327 break;
328 case arm_compute::DataType::S32:
329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330 static_cast<int32_t*>(memory));
331 break;
332 default:
333 {
334 throw armnn::UnimplementedException();
335 }
336 }
337 const_cast<ClImportSubTensorHandle*>(this)->Unmap();
338 }
339
340 // Only used for testing
341 void CopyInFrom(const void* memory) override
342 {
343 this->Map(true);
344 switch(this->GetDataType())
345 {
346 case arm_compute::DataType::F32:
347 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
348 this->GetTensor());
349 break;
350 case arm_compute::DataType::U8:
351 case arm_compute::DataType::QASYMM8:
352 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
353 this->GetTensor());
354 break;
355 case arm_compute::DataType::F16:
356 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
357 this->GetTensor());
358 break;
359 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
360 case arm_compute::DataType::QASYMM8_SIGNED:
361 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
362 this->GetTensor());
363 break;
364 case arm_compute::DataType::S16:
365 case arm_compute::DataType::QSYMM16:
366 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
367 this->GetTensor());
368 break;
369 case arm_compute::DataType::S32:
370 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
371 this->GetTensor());
372 break;
373 default:
374 {
375 throw armnn::UnimplementedException();
376 }
377 }
378 this->Unmap();
379 }
380
381 mutable arm_compute::CLSubTensor m_Tensor;
382 ITensorHandle* parentHandle = nullptr;
383};
384
385} // namespace armnn