//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
6#pragma once
7
8#include <aclCommon/ArmComputeTensorHandle.hpp>
9#include <aclCommon/ArmComputeTensorUtils.hpp>
10
11#include <Half.hpp>
12
13#include <armnn/utility/PolymorphicDowncast.hpp>
14
15#include <arm_compute/runtime/CL/CLTensor.h>
16#include <arm_compute/runtime/CL/CLSubTensor.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
18#include <arm_compute/runtime/MemoryGroup.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
22#include <include/CL/cl_ext.h>
23#include <arm_compute/core/CL/CLKernelLibrary.h>
24
25namespace armnn
26{
27
// Interface for CL tensor handles whose backing memory is imported from an
// external source rather than allocated by the Arm NN runtime (see the
// concrete handles below, which import host-malloc'd and dma_buf memory).
class IClImportTensorHandle : public IAclTensorHandle
{
public:
    // Access to the underlying Compute Library CL tensor.
    virtual arm_compute::ICLTensor& GetTensor() = 0;
    virtual arm_compute::ICLTensor const& GetTensor() const = 0;
    // Element data type as recorded in the tensor's TensorInfo.
    virtual arm_compute::DataType GetDataType() const = 0;
    // NOTE(review): both implementations in this file ignore the memory group,
    // since imported memory is owned externally and never pooled.
    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
};
36
37class ClImportTensorHandle : public IClImportTensorHandle
38{
39public:
40 ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
41 : m_ImportFlags(importFlags)
42 {
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
44 }
45
46 ClImportTensorHandle(const TensorInfo& tensorInfo,
47 DataLayout dataLayout,
48 MemorySourceFlags importFlags)
49 : m_ImportFlags(importFlags)
50 {
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
52 }
53
54 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
55 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
56 virtual void Allocate() override {}
57 virtual void Manage() override {}
58
59 virtual const void* Map(bool blocking = true) const override
60 {
61 IgnoreUnused(blocking);
62 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
63 }
64
65 virtual void Unmap() const override {}
66
67 virtual ITensorHandle* GetParent() const override { return nullptr; }
68
69 virtual arm_compute::DataType GetDataType() const override
70 {
71 return m_Tensor.info()->data_type();
72 }
73
74 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 {
76 IgnoreUnused(memoryGroup);
77 }
78
79 TensorShape GetStrides() const override
80 {
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82 }
83
84 TensorShape GetShape() const override
85 {
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87 }
88
89 void SetImportFlags(MemorySourceFlags importFlags)
90 {
91 m_ImportFlags = importFlags;
92 }
93
94 MemorySourceFlags GetImportFlags() const override
95 {
96 return m_ImportFlags;
97 }
98
99 virtual bool Import(void* memory, MemorySource source) override
100 {
101 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
102 {
103 if (source == MemorySource::Malloc)
104 {
David Monahane4a41dc2021-04-14 16:55:36 +0100105 const cl_import_properties_arm importProperties[] =
106 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100107 CL_IMPORT_TYPE_ARM,
108 CL_IMPORT_TYPE_HOST_ARM,
109 0
David Monahane4a41dc2021-04-14 16:55:36 +0100110 };
111
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100112 return ClImport(importProperties, memory);
113 }
114 if (source == MemorySource::DmaBuf)
115 {
116 const cl_import_properties_arm importProperties[] =
David Monahane4a41dc2021-04-14 16:55:36 +0100117 {
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100118 CL_IMPORT_TYPE_ARM,
119 CL_IMPORT_TYPE_DMA_BUF_ARM,
120 0
121 };
David Monahane4a41dc2021-04-14 16:55:36 +0100122
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100123 return ClImport(importProperties, memory);
David Monahane4a41dc2021-04-14 16:55:36 +0100124
David Monahane4a41dc2021-04-14 16:55:36 +0100125 }
126 else
127 {
128 throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
129 }
130 }
131 else
132 {
133 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
134 }
135 return false;
136 }
137
138private:
Narumol Prangnawaratff9a29d2021-05-10 11:02:58 +0100139 bool ClImport(const cl_import_properties_arm* importProperties, void* memory)
140 {
141 const size_t totalBytes = m_Tensor.info()->total_size();
142 cl_int error = CL_SUCCESS;
143 cl_mem buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
144 CL_MEM_READ_WRITE, importProperties, memory, totalBytes, &error);
145 if (error != CL_SUCCESS)
146 {
147 throw MemoryImportException("ClImportTensorHandle::Invalid imported memory");
148 }
149
150 cl::Buffer wrappedBuffer(buffer);
151 arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);
152
153 // Use the overloaded bool operator of Status to check if it is success, if not throw an exception
154 // with the Status error message
155 bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
156 if (!imported)
157 {
158 throw MemoryImportException(status.error_description());
159 }
160
161 ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
162 return imported;
163 }
David Monahane4a41dc2021-04-14 16:55:36 +0100164 // Only used for testing
165 void CopyOutTo(void* memory) const override
166 {
167 const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
168 switch(this->GetDataType())
169 {
170 case arm_compute::DataType::F32:
171 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
172 static_cast<float*>(memory));
173 break;
174 case arm_compute::DataType::U8:
175 case arm_compute::DataType::QASYMM8:
176 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
177 static_cast<uint8_t*>(memory));
178 break;
179 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
180 case arm_compute::DataType::QASYMM8_SIGNED:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<int8_t*>(memory));
183 break;
184 case arm_compute::DataType::F16:
185 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
186 static_cast<armnn::Half*>(memory));
187 break;
188 case arm_compute::DataType::S16:
189 case arm_compute::DataType::QSYMM16:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<int16_t*>(memory));
192 break;
193 case arm_compute::DataType::S32:
194 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195 static_cast<int32_t*>(memory));
196 break;
197 default:
198 {
199 throw armnn::UnimplementedException();
200 }
201 }
202 const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
203 }
204
205 // Only used for testing
206 void CopyInFrom(const void* memory) override
207 {
208 this->Map(true);
209 switch(this->GetDataType())
210 {
211 case arm_compute::DataType::F32:
212 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
213 this->GetTensor());
214 break;
215 case arm_compute::DataType::U8:
216 case arm_compute::DataType::QASYMM8:
217 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
218 this->GetTensor());
219 break;
220 case arm_compute::DataType::F16:
221 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
222 this->GetTensor());
223 break;
224 case arm_compute::DataType::S16:
225 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
226 case arm_compute::DataType::QASYMM8_SIGNED:
227 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
228 this->GetTensor());
229 break;
230 case arm_compute::DataType::QSYMM16:
231 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
232 this->GetTensor());
233 break;
234 case arm_compute::DataType::S32:
235 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
236 this->GetTensor());
237 break;
238 default:
239 {
240 throw armnn::UnimplementedException();
241 }
242 }
243 this->Unmap();
244 }
245
246 arm_compute::CLTensor m_Tensor;
247 MemorySourceFlags m_ImportFlags;
248};
249
250class ClImportSubTensorHandle : public IClImportTensorHandle
251{
252public:
253 ClImportSubTensorHandle(IClImportTensorHandle* parent,
254 const arm_compute::TensorShape& shape,
255 const arm_compute::Coordinates& coords)
256 : m_Tensor(&parent->GetTensor(), shape, coords)
257 {
258 parentHandle = parent;
259 }
260
261 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
262 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
263
264 virtual void Allocate() override {}
265 virtual void Manage() override {}
266
267 virtual const void* Map(bool blocking = true) const override
268 {
269 IgnoreUnused(blocking);
270 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
271 }
272 virtual void Unmap() const override {}
273
274 virtual ITensorHandle* GetParent() const override { return parentHandle; }
275
276 virtual arm_compute::DataType GetDataType() const override
277 {
278 return m_Tensor.info()->data_type();
279 }
280
281 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
282 {
283 IgnoreUnused(memoryGroup);
284 }
285
286 TensorShape GetStrides() const override
287 {
288 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
289 }
290
291 TensorShape GetShape() const override
292 {
293 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
294 }
295
296private:
297 // Only used for testing
298 void CopyOutTo(void* memory) const override
299 {
300 const_cast<ClImportSubTensorHandle*>(this)->Map(true);
301 switch(this->GetDataType())
302 {
303 case arm_compute::DataType::F32:
304 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
305 static_cast<float*>(memory));
306 break;
307 case arm_compute::DataType::U8:
308 case arm_compute::DataType::QASYMM8:
309 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
310 static_cast<uint8_t*>(memory));
311 break;
312 case arm_compute::DataType::F16:
313 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
314 static_cast<armnn::Half*>(memory));
315 break;
316 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
317 case arm_compute::DataType::QASYMM8_SIGNED:
318 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
319 static_cast<int8_t*>(memory));
320 break;
321 case arm_compute::DataType::S16:
322 case arm_compute::DataType::QSYMM16:
323 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
324 static_cast<int16_t*>(memory));
325 break;
326 case arm_compute::DataType::S32:
327 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
328 static_cast<int32_t*>(memory));
329 break;
330 default:
331 {
332 throw armnn::UnimplementedException();
333 }
334 }
335 const_cast<ClImportSubTensorHandle*>(this)->Unmap();
336 }
337
338 // Only used for testing
339 void CopyInFrom(const void* memory) override
340 {
341 this->Map(true);
342 switch(this->GetDataType())
343 {
344 case arm_compute::DataType::F32:
345 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
346 this->GetTensor());
347 break;
348 case arm_compute::DataType::U8:
349 case arm_compute::DataType::QASYMM8:
350 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
351 this->GetTensor());
352 break;
353 case arm_compute::DataType::F16:
354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
355 this->GetTensor());
356 break;
357 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
358 case arm_compute::DataType::QASYMM8_SIGNED:
359 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
360 this->GetTensor());
361 break;
362 case arm_compute::DataType::S16:
363 case arm_compute::DataType::QSYMM16:
364 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
365 this->GetTensor());
366 break;
367 case arm_compute::DataType::S32:
368 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
369 this->GetTensor());
370 break;
371 default:
372 {
373 throw armnn::UnimplementedException();
374 }
375 }
376 this->Unmap();
377 }
378
379 mutable arm_compute::CLSubTensor m_Tensor;
380 ITensorHandle* parentHandle = nullptr;
381};
382
383} // namespace armnn