//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <include/CL/cl_ext.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>

namespace armnn
{

28class IClImportTensorHandle : public IAclTensorHandle
29{
30public:
31 virtual arm_compute::ICLTensor& GetTensor() = 0;
32 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
33 virtual arm_compute::DataType GetDataType() const = 0;
34 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
35};
36
37class ClImportTensorHandle : public IClImportTensorHandle
38{
39public:
40 ClImportTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
41 : m_ImportFlags(importFlags)
42 {
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
44 }
45
46 ClImportTensorHandle(const TensorInfo& tensorInfo,
47 DataLayout dataLayout,
48 MemorySourceFlags importFlags)
49 : m_ImportFlags(importFlags)
50 {
51 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
52 }
53
54 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
55 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
56 virtual void Allocate() override {}
57 virtual void Manage() override {}
58
59 virtual const void* Map(bool blocking = true) const override
60 {
61 IgnoreUnused(blocking);
62 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
63 }
64
65 virtual void Unmap() const override {}
66
67 virtual ITensorHandle* GetParent() const override { return nullptr; }
68
69 virtual arm_compute::DataType GetDataType() const override
70 {
71 return m_Tensor.info()->data_type();
72 }
73
74 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 {
76 IgnoreUnused(memoryGroup);
77 }
78
79 TensorShape GetStrides() const override
80 {
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82 }
83
84 TensorShape GetShape() const override
85 {
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87 }
88
89 void SetImportFlags(MemorySourceFlags importFlags)
90 {
91 m_ImportFlags = importFlags;
92 }
93
94 MemorySourceFlags GetImportFlags() const override
95 {
96 return m_ImportFlags;
97 }
98
99 virtual bool Import(void* memory, MemorySource source) override
100 {
101 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
102 {
103 if (source == MemorySource::Malloc)
104 {
105 const size_t totalBytes = m_Tensor.info()->total_size();
106
107 const cl_import_properties_arm importProperties[] =
108 {
109 CL_IMPORT_TYPE_ARM,
110 CL_IMPORT_TYPE_HOST_ARM,
111 0
112 };
113
114 cl_int error = CL_SUCCESS;
115 cl_mem buffer = clImportMemoryARM(arm_compute::CLKernelLibrary::get().context().get(),
116 CL_MEM_READ_WRITE, importProperties, memory, totalBytes, &error);
117 if (error != CL_SUCCESS)
118 {
119 throw MemoryImportException(
120 "ClImportTensorHandle::Invalid imported memory:" + std::to_string(error));
121 }
122
123 cl::Buffer wrappedBuffer(buffer);
124 arm_compute::Status status = m_Tensor.allocator()->import_memory(wrappedBuffer);
125
126 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
127 // with the Status error message
128 bool imported = (status.error_code() == arm_compute::ErrorCode::OK);
129 if (!imported)
130 {
131 throw MemoryImportException(status.error_description());
132 }
133 ARMNN_ASSERT(!m_Tensor.info()->is_resizable());
134 return imported;
135 }
136 else
137 {
138 throw MemoryImportException("ClImportTensorHandle::Import flag is not supported");
139 }
140 }
141 else
142 {
143 throw MemoryImportException("ClImportTensorHandle::Incorrect import flag");
144 }
145 return false;
146 }
147
148private:
149 // Only used for testing
150 void CopyOutTo(void* memory) const override
151 {
152 const_cast<armnn::ClImportTensorHandle*>(this)->Map(true);
153 switch(this->GetDataType())
154 {
155 case arm_compute::DataType::F32:
156 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
157 static_cast<float*>(memory));
158 break;
159 case arm_compute::DataType::U8:
160 case arm_compute::DataType::QASYMM8:
161 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
162 static_cast<uint8_t*>(memory));
163 break;
164 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
165 case arm_compute::DataType::QASYMM8_SIGNED:
166 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
167 static_cast<int8_t*>(memory));
168 break;
169 case arm_compute::DataType::F16:
170 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
171 static_cast<armnn::Half*>(memory));
172 break;
173 case arm_compute::DataType::S16:
174 case arm_compute::DataType::QSYMM16:
175 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
176 static_cast<int16_t*>(memory));
177 break;
178 case arm_compute::DataType::S32:
179 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
180 static_cast<int32_t*>(memory));
181 break;
182 default:
183 {
184 throw armnn::UnimplementedException();
185 }
186 }
187 const_cast<armnn::ClImportTensorHandle*>(this)->Unmap();
188 }
189
190 // Only used for testing
191 void CopyInFrom(const void* memory) override
192 {
193 this->Map(true);
194 switch(this->GetDataType())
195 {
196 case arm_compute::DataType::F32:
197 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
198 this->GetTensor());
199 break;
200 case arm_compute::DataType::U8:
201 case arm_compute::DataType::QASYMM8:
202 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
203 this->GetTensor());
204 break;
205 case arm_compute::DataType::F16:
206 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
207 this->GetTensor());
208 break;
209 case arm_compute::DataType::S16:
210 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
211 case arm_compute::DataType::QASYMM8_SIGNED:
212 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
213 this->GetTensor());
214 break;
215 case arm_compute::DataType::QSYMM16:
216 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
217 this->GetTensor());
218 break;
219 case arm_compute::DataType::S32:
220 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
221 this->GetTensor());
222 break;
223 default:
224 {
225 throw armnn::UnimplementedException();
226 }
227 }
228 this->Unmap();
229 }
230
231 arm_compute::CLTensor m_Tensor;
232 MemorySourceFlags m_ImportFlags;
233};
234
235class ClImportSubTensorHandle : public IClImportTensorHandle
236{
237public:
238 ClImportSubTensorHandle(IClImportTensorHandle* parent,
239 const arm_compute::TensorShape& shape,
240 const arm_compute::Coordinates& coords)
241 : m_Tensor(&parent->GetTensor(), shape, coords)
242 {
243 parentHandle = parent;
244 }
245
246 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
247 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
248
249 virtual void Allocate() override {}
250 virtual void Manage() override {}
251
252 virtual const void* Map(bool blocking = true) const override
253 {
254 IgnoreUnused(blocking);
255 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
256 }
257 virtual void Unmap() const override {}
258
259 virtual ITensorHandle* GetParent() const override { return parentHandle; }
260
261 virtual arm_compute::DataType GetDataType() const override
262 {
263 return m_Tensor.info()->data_type();
264 }
265
266 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
267 {
268 IgnoreUnused(memoryGroup);
269 }
270
271 TensorShape GetStrides() const override
272 {
273 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
274 }
275
276 TensorShape GetShape() const override
277 {
278 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
279 }
280
281private:
282 // Only used for testing
283 void CopyOutTo(void* memory) const override
284 {
285 const_cast<ClImportSubTensorHandle*>(this)->Map(true);
286 switch(this->GetDataType())
287 {
288 case arm_compute::DataType::F32:
289 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
290 static_cast<float*>(memory));
291 break;
292 case arm_compute::DataType::U8:
293 case arm_compute::DataType::QASYMM8:
294 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
295 static_cast<uint8_t*>(memory));
296 break;
297 case arm_compute::DataType::F16:
298 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
299 static_cast<armnn::Half*>(memory));
300 break;
301 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
302 case arm_compute::DataType::QASYMM8_SIGNED:
303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304 static_cast<int8_t*>(memory));
305 break;
306 case arm_compute::DataType::S16:
307 case arm_compute::DataType::QSYMM16:
308 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309 static_cast<int16_t*>(memory));
310 break;
311 case arm_compute::DataType::S32:
312 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
313 static_cast<int32_t*>(memory));
314 break;
315 default:
316 {
317 throw armnn::UnimplementedException();
318 }
319 }
320 const_cast<ClImportSubTensorHandle*>(this)->Unmap();
321 }
322
323 // Only used for testing
324 void CopyInFrom(const void* memory) override
325 {
326 this->Map(true);
327 switch(this->GetDataType())
328 {
329 case arm_compute::DataType::F32:
330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
331 this->GetTensor());
332 break;
333 case arm_compute::DataType::U8:
334 case arm_compute::DataType::QASYMM8:
335 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
336 this->GetTensor());
337 break;
338 case arm_compute::DataType::F16:
339 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
340 this->GetTensor());
341 break;
342 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
343 case arm_compute::DataType::QASYMM8_SIGNED:
344 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
345 this->GetTensor());
346 break;
347 case arm_compute::DataType::S16:
348 case arm_compute::DataType::QSYMM16:
349 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
350 this->GetTensor());
351 break;
352 case arm_compute::DataType::S32:
353 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
354 this->GetTensor());
355 break;
356 default:
357 {
358 throw armnn::UnimplementedException();
359 }
360 }
361 this->Unmap();
362 }
363
364 mutable arm_compute::CLSubTensor m_Tensor;
365 ITensorHandle* parentHandle = nullptr;
366};
367
} // namespace armnn