blob: ae8aa5d8c710d86fc7dd4430b30aad61cf571a94 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
32 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000033 {
34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35 }
36
David Monahan3fb7e102019-08-20 11:25:29 +010037 NeonTensorHandle(const TensorInfo& tensorInfo,
38 DataLayout dataLayout,
39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
40 : m_ImportFlags(importFlags),
41 m_Imported(false),
42 m_IsImportEnabled(false)
43
Francis Murtagh351d13d2018-09-24 15:01:18 +010044 {
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46 }
47
telsoa014fcda012018-03-09 14:13:49 +000048 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
49 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010050
telsoa014fcda012018-03-09 14:13:49 +000051 virtual void Allocate() override
52 {
David Monahan3fb7e102019-08-20 11:25:29 +010053 // If we have enabled Importing, don't Allocate the tensor
54 if (!m_IsImportEnabled)
55 {
56 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
57 }
telsoa014fcda012018-03-09 14:13:49 +000058 };
59
telsoa01c577f2c2018-08-31 09:22:23 +010060 virtual void Manage() override
61 {
David Monahan3fb7e102019-08-20 11:25:29 +010062 // If we have enabled Importing, don't manage the tensor
63 if (!m_IsImportEnabled)
64 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010065 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010066 m_MemoryGroup->manage(&m_Tensor);
67 }
telsoa01c577f2c2018-08-31 09:22:23 +010068 }
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 virtual ITensorHandle* GetParent() const override { return nullptr; }
71
telsoa014fcda012018-03-09 14:13:49 +000072 virtual arm_compute::DataType GetDataType() const override
73 {
74 return m_Tensor.info()->data_type();
75 }
76
telsoa01c577f2c2018-08-31 09:22:23 +010077 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
78 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010079 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010080 }
81
82 virtual const void* Map(bool /* blocking = true */) const override
83 {
84 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85 }
telsoa01c577f2c2018-08-31 09:22:23 +010086
David Monahan3fb7e102019-08-20 11:25:29 +010087 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010088
89 TensorShape GetStrides() const override
90 {
91 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
92 }
93
94 TensorShape GetShape() const override
95 {
96 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
97 }
98
David Monahan3fb7e102019-08-20 11:25:29 +010099 void SetImportFlags(MemorySourceFlags importFlags)
100 {
101 m_ImportFlags = importFlags;
102 }
103
104 MemorySourceFlags GetImportFlags() const override
105 {
106 return m_ImportFlags;
107 }
108
109 void SetImportEnabledFlag(bool importEnabledFlag)
110 {
111 m_IsImportEnabled = importEnabledFlag;
112 }
113
114 virtual bool Import(void* memory, MemorySource source) override
115 {
116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117 {
118 if (source == MemorySource::Malloc && m_IsImportEnabled)
119 {
120 // Checks the 16 byte memory alignment
121 constexpr uintptr_t alignment = sizeof(size_t);
122 if (reinterpret_cast<uintptr_t>(memory) % alignment)
123 {
124 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
125 }
126
127 // m_Tensor not yet Allocated
128 if (!m_Imported && !m_Tensor.buffer())
129 {
130 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
131 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
132 // with the Status error message
133 m_Imported = bool(status);
134 if (!m_Imported)
135 {
136 throw MemoryImportException(status.error_description());
137 }
138 return m_Imported;
139 }
140
141 // m_Tensor.buffer() initially allocated with Allocate().
142 if (!m_Imported && m_Tensor.buffer())
143 {
144 throw MemoryImportException(
145 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
146 }
147
148 // m_Tensor.buffer() previously imported.
149 if (m_Imported)
150 {
151 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
152 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
153 // with the Status error message
154 m_Imported = bool(status);
155 if (!m_Imported)
156 {
157 throw MemoryImportException(status.error_description());
158 }
159 return m_Imported;
160 }
161 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100162 else
163 {
164 throw MemoryImportException("NeonTensorHandle::Import is disabled");
165 }
166 }
167 else
168 {
169 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100170 }
171 return false;
172 }
173
telsoa014fcda012018-03-09 14:13:49 +0000174private:
David Beck09e2f272018-10-30 11:38:41 +0000175 // Only used for testing
176 void CopyOutTo(void* memory) const override
177 {
178 switch (this->GetDataType())
179 {
180 case arm_compute::DataType::F32:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<float*>(memory));
183 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000184 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000185 case arm_compute::DataType::QASYMM8:
186 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187 static_cast<uint8_t*>(memory));
188 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100189 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100190 case arm_compute::DataType::QASYMM8_SIGNED:
191 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
192 static_cast<int8_t*>(memory));
193 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100194 case arm_compute::DataType::BFLOAT16:
195 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
196 static_cast<armnn::BFloat16*>(memory));
197 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100198 case arm_compute::DataType::F16:
199 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
200 static_cast<armnn::Half*>(memory));
201 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100202 case arm_compute::DataType::S16:
203 case arm_compute::DataType::QSYMM16:
204 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
205 static_cast<int16_t*>(memory));
206 break;
James Conroyd47a0642019-09-17 14:22:06 +0100207 case arm_compute::DataType::S32:
208 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
209 static_cast<int32_t*>(memory));
210 break;
David Beck09e2f272018-10-30 11:38:41 +0000211 default:
212 {
213 throw armnn::UnimplementedException();
214 }
215 }
216 }
217
218 // Only used for testing
219 void CopyInFrom(const void* memory) override
220 {
221 switch (this->GetDataType())
222 {
223 case arm_compute::DataType::F32:
224 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
225 this->GetTensor());
226 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000227 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000228 case arm_compute::DataType::QASYMM8:
229 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
230 this->GetTensor());
231 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100232 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100233 case arm_compute::DataType::QASYMM8_SIGNED:
234 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
235 this->GetTensor());
236 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100237 case arm_compute::DataType::BFLOAT16:
238 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
239 this->GetTensor());
240 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100241 case arm_compute::DataType::F16:
242 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
243 this->GetTensor());
244 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100245 case arm_compute::DataType::S16:
246 case arm_compute::DataType::QSYMM16:
247 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
248 this->GetTensor());
249 break;
James Conroyd47a0642019-09-17 14:22:06 +0100250 case arm_compute::DataType::S32:
251 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
252 this->GetTensor());
253 break;
David Beck09e2f272018-10-30 11:38:41 +0000254 default:
255 {
256 throw armnn::UnimplementedException();
257 }
258 }
259 }
260
telsoa014fcda012018-03-09 14:13:49 +0000261 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100262 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100263 MemorySourceFlags m_ImportFlags;
264 bool m_Imported;
265 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000266};
267
Derek Lambertic81855f2019-06-13 17:34:19 +0100268class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000269{
270public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100271 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100272 const arm_compute::TensorShape& shape,
273 const arm_compute::Coordinates& coords)
274 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000275 {
telsoa01c577f2c2018-08-31 09:22:23 +0100276 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000277 }
278
279 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
280 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100281
282 virtual void Allocate() override {}
283 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000284
telsoa01c577f2c2018-08-31 09:22:23 +0100285 virtual ITensorHandle* GetParent() const override { return parentHandle; }
286
telsoa014fcda012018-03-09 14:13:49 +0000287 virtual arm_compute::DataType GetDataType() const override
288 {
289 return m_Tensor.info()->data_type();
290 }
291
telsoa01c577f2c2018-08-31 09:22:23 +0100292 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
293
294 virtual const void* Map(bool /* blocking = true */) const override
295 {
296 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
297 }
298 virtual void Unmap() const override {}
299
300 TensorShape GetStrides() const override
301 {
302 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
303 }
304
305 TensorShape GetShape() const override
306 {
307 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
308 }
David Beck09e2f272018-10-30 11:38:41 +0000309
telsoa014fcda012018-03-09 14:13:49 +0000310private:
David Beck09e2f272018-10-30 11:38:41 +0000311 // Only used for testing
312 void CopyOutTo(void* memory) const override
313 {
314 switch (this->GetDataType())
315 {
316 case arm_compute::DataType::F32:
317 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
318 static_cast<float*>(memory));
319 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000320 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000321 case arm_compute::DataType::QASYMM8:
322 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
323 static_cast<uint8_t*>(memory));
324 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100325 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100326 case arm_compute::DataType::QASYMM8_SIGNED:
327 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
328 static_cast<int8_t*>(memory));
329 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100330 case arm_compute::DataType::S16:
331 case arm_compute::DataType::QSYMM16:
332 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
333 static_cast<int16_t*>(memory));
334 break;
James Conroyd47a0642019-09-17 14:22:06 +0100335 case arm_compute::DataType::S32:
336 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
337 static_cast<int32_t*>(memory));
338 break;
David Beck09e2f272018-10-30 11:38:41 +0000339 default:
340 {
341 throw armnn::UnimplementedException();
342 }
343 }
344 }
345
346 // Only used for testing
347 void CopyInFrom(const void* memory) override
348 {
349 switch (this->GetDataType())
350 {
351 case arm_compute::DataType::F32:
352 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
353 this->GetTensor());
354 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000355 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000356 case arm_compute::DataType::QASYMM8:
357 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
358 this->GetTensor());
359 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100360 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100361 case arm_compute::DataType::QASYMM8_SIGNED:
362 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
363 this->GetTensor());
364 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100365 case arm_compute::DataType::S16:
366 case arm_compute::DataType::QSYMM16:
367 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
368 this->GetTensor());
369 break;
James Conroyd47a0642019-09-17 14:22:06 +0100370 case arm_compute::DataType::S32:
371 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
372 this->GetTensor());
373 break;
David Beck09e2f272018-10-30 11:38:41 +0000374 default:
375 {
376 throw armnn::UnimplementedException();
377 }
378 }
379 }
380
telsoa01c577f2c2018-08-31 09:22:23 +0100381 arm_compute::SubTensor m_Tensor;
382 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000383};
384
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100385} // namespace armnn