blob: 9445cb1c75457f447d9100517b76262a2ca999a0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010032 m_IsImportEnabled(false),
33 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010043 m_IsImportEnabled(false),
44 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
45
David Monahan3fb7e102019-08-20 11:25:29 +010046
Francis Murtagh351d13d2018-09-24 15:01:18 +010047 {
48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 }
50
telsoa014fcda012018-03-09 14:13:49 +000051 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
52 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010053
telsoa014fcda012018-03-09 14:13:49 +000054 virtual void Allocate() override
55 {
David Monahan3fb7e102019-08-20 11:25:29 +010056 // If we have enabled Importing, don't Allocate the tensor
57 if (!m_IsImportEnabled)
58 {
59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
60 }
telsoa014fcda012018-03-09 14:13:49 +000061 };
62
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Manage() override
64 {
David Monahan3fb7e102019-08-20 11:25:29 +010065 // If we have enabled Importing, don't manage the tensor
66 if (!m_IsImportEnabled)
67 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010068 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010069 m_MemoryGroup->manage(&m_Tensor);
70 }
telsoa01c577f2c2018-08-31 09:22:23 +010071 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual ITensorHandle* GetParent() const override { return nullptr; }
74
telsoa014fcda012018-03-09 14:13:49 +000075 virtual arm_compute::DataType GetDataType() const override
76 {
77 return m_Tensor.info()->data_type();
78 }
79
telsoa01c577f2c2018-08-31 09:22:23 +010080 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
81 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010082 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010083 }
84
85 virtual const void* Map(bool /* blocking = true */) const override
86 {
87 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
88 }
telsoa01c577f2c2018-08-31 09:22:23 +010089
David Monahan3fb7e102019-08-20 11:25:29 +010090 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010091
92 TensorShape GetStrides() const override
93 {
94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
95 }
96
97 TensorShape GetShape() const override
98 {
99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
100 }
101
David Monahan3fb7e102019-08-20 11:25:29 +0100102 void SetImportFlags(MemorySourceFlags importFlags)
103 {
104 m_ImportFlags = importFlags;
105 }
106
107 MemorySourceFlags GetImportFlags() const override
108 {
109 return m_ImportFlags;
110 }
111
112 void SetImportEnabledFlag(bool importEnabledFlag)
113 {
114 m_IsImportEnabled = importEnabledFlag;
115 }
116
David Monahan0fa10502022-01-13 10:48:33 +0000117 bool CanBeImported(void* memory, MemorySource source) override
118 {
David Monahan3826ab62022-02-21 12:26:16 +0000119 if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
David Monahan0fa10502022-01-13 10:48:33 +0000120 {
121 return false;
122 }
123 return true;
124 }
125
David Monahan3fb7e102019-08-20 11:25:29 +0100126 virtual bool Import(void* memory, MemorySource source) override
127 {
128 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
129 {
130 if (source == MemorySource::Malloc && m_IsImportEnabled)
131 {
David Monahan0fa10502022-01-13 10:48:33 +0000132 if (!CanBeImported(memory, source))
David Monahan3fb7e102019-08-20 11:25:29 +0100133 {
134 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
135 }
136
137 // m_Tensor not yet Allocated
138 if (!m_Imported && !m_Tensor.buffer())
139 {
140 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
141 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
142 // with the Status error message
143 m_Imported = bool(status);
144 if (!m_Imported)
145 {
146 throw MemoryImportException(status.error_description());
147 }
148 return m_Imported;
149 }
150
151 // m_Tensor.buffer() initially allocated with Allocate().
152 if (!m_Imported && m_Tensor.buffer())
153 {
154 throw MemoryImportException(
155 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
156 }
157
158 // m_Tensor.buffer() previously imported.
159 if (m_Imported)
160 {
161 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
162 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
163 // with the Status error message
164 m_Imported = bool(status);
165 if (!m_Imported)
166 {
167 throw MemoryImportException(status.error_description());
168 }
169 return m_Imported;
170 }
171 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100172 else
173 {
174 throw MemoryImportException("NeonTensorHandle::Import is disabled");
175 }
176 }
177 else
178 {
179 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100180 }
181 return false;
182 }
183
telsoa014fcda012018-03-09 14:13:49 +0000184private:
David Beck09e2f272018-10-30 11:38:41 +0000185 // Only used for testing
186 void CopyOutTo(void* memory) const override
187 {
188 switch (this->GetDataType())
189 {
190 case arm_compute::DataType::F32:
191 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
192 static_cast<float*>(memory));
193 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000194 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000195 case arm_compute::DataType::QASYMM8:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<uint8_t*>(memory));
198 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100199 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100200 case arm_compute::DataType::QASYMM8_SIGNED:
201 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
202 static_cast<int8_t*>(memory));
203 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100204 case arm_compute::DataType::BFLOAT16:
205 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
206 static_cast<armnn::BFloat16*>(memory));
207 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100208 case arm_compute::DataType::F16:
209 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
210 static_cast<armnn::Half*>(memory));
211 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100212 case arm_compute::DataType::S16:
213 case arm_compute::DataType::QSYMM16:
214 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
215 static_cast<int16_t*>(memory));
216 break;
James Conroyd47a0642019-09-17 14:22:06 +0100217 case arm_compute::DataType::S32:
218 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
219 static_cast<int32_t*>(memory));
220 break;
David Beck09e2f272018-10-30 11:38:41 +0000221 default:
222 {
223 throw armnn::UnimplementedException();
224 }
225 }
226 }
227
228 // Only used for testing
229 void CopyInFrom(const void* memory) override
230 {
231 switch (this->GetDataType())
232 {
233 case arm_compute::DataType::F32:
234 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
235 this->GetTensor());
236 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000237 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000238 case arm_compute::DataType::QASYMM8:
239 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
240 this->GetTensor());
241 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100242 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100243 case arm_compute::DataType::QASYMM8_SIGNED:
244 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
245 this->GetTensor());
246 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100247 case arm_compute::DataType::BFLOAT16:
248 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
249 this->GetTensor());
250 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100251 case arm_compute::DataType::F16:
252 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
253 this->GetTensor());
254 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100255 case arm_compute::DataType::S16:
256 case arm_compute::DataType::QSYMM16:
257 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
258 this->GetTensor());
259 break;
James Conroyd47a0642019-09-17 14:22:06 +0100260 case arm_compute::DataType::S32:
261 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
262 this->GetTensor());
263 break;
David Beck09e2f272018-10-30 11:38:41 +0000264 default:
265 {
266 throw armnn::UnimplementedException();
267 }
268 }
269 }
270
telsoa014fcda012018-03-09 14:13:49 +0000271 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100272 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100273 MemorySourceFlags m_ImportFlags;
274 bool m_Imported;
275 bool m_IsImportEnabled;
Finn Williamsb1aad422021-10-28 19:07:32 +0100276 const uintptr_t m_TypeAlignment;
telsoa014fcda012018-03-09 14:13:49 +0000277};
278
Derek Lambertic81855f2019-06-13 17:34:19 +0100279class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000280{
281public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100282 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100283 const arm_compute::TensorShape& shape,
284 const arm_compute::Coordinates& coords)
285 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000286 {
telsoa01c577f2c2018-08-31 09:22:23 +0100287 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000288 }
289
290 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
291 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100292
293 virtual void Allocate() override {}
294 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000295
telsoa01c577f2c2018-08-31 09:22:23 +0100296 virtual ITensorHandle* GetParent() const override { return parentHandle; }
297
telsoa014fcda012018-03-09 14:13:49 +0000298 virtual arm_compute::DataType GetDataType() const override
299 {
300 return m_Tensor.info()->data_type();
301 }
302
telsoa01c577f2c2018-08-31 09:22:23 +0100303 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
304
305 virtual const void* Map(bool /* blocking = true */) const override
306 {
307 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
308 }
309 virtual void Unmap() const override {}
310
311 TensorShape GetStrides() const override
312 {
313 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
314 }
315
316 TensorShape GetShape() const override
317 {
318 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
319 }
David Beck09e2f272018-10-30 11:38:41 +0000320
telsoa014fcda012018-03-09 14:13:49 +0000321private:
David Beck09e2f272018-10-30 11:38:41 +0000322 // Only used for testing
323 void CopyOutTo(void* memory) const override
324 {
325 switch (this->GetDataType())
326 {
327 case arm_compute::DataType::F32:
328 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
329 static_cast<float*>(memory));
330 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000331 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000332 case arm_compute::DataType::QASYMM8:
333 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
334 static_cast<uint8_t*>(memory));
335 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100336 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100337 case arm_compute::DataType::QASYMM8_SIGNED:
338 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
339 static_cast<int8_t*>(memory));
340 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100341 case arm_compute::DataType::S16:
342 case arm_compute::DataType::QSYMM16:
343 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
344 static_cast<int16_t*>(memory));
345 break;
James Conroyd47a0642019-09-17 14:22:06 +0100346 case arm_compute::DataType::S32:
347 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
348 static_cast<int32_t*>(memory));
349 break;
David Beck09e2f272018-10-30 11:38:41 +0000350 default:
351 {
352 throw armnn::UnimplementedException();
353 }
354 }
355 }
356
357 // Only used for testing
358 void CopyInFrom(const void* memory) override
359 {
360 switch (this->GetDataType())
361 {
362 case arm_compute::DataType::F32:
363 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
364 this->GetTensor());
365 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000366 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000367 case arm_compute::DataType::QASYMM8:
368 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
369 this->GetTensor());
370 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100371 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100372 case arm_compute::DataType::QASYMM8_SIGNED:
373 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
374 this->GetTensor());
375 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100376 case arm_compute::DataType::S16:
377 case arm_compute::DataType::QSYMM16:
378 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
379 this->GetTensor());
380 break;
James Conroyd47a0642019-09-17 14:22:06 +0100381 case arm_compute::DataType::S32:
382 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
383 this->GetTensor());
384 break;
David Beck09e2f272018-10-30 11:38:41 +0000385 default:
386 {
387 throw armnn::UnimplementedException();
388 }
389 }
390 }
391
telsoa01c577f2c2018-08-31 09:22:23 +0100392 arm_compute::SubTensor m_Tensor;
393 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000394};
395
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100396} // namespace armnn