blob: f40b5fc2e5962a17ff0b14d7ef32efb0d3633fdb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010032 m_IsImportEnabled(false),
33 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010043 m_IsImportEnabled(false),
44 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
45
David Monahan3fb7e102019-08-20 11:25:29 +010046
Francis Murtagh351d13d2018-09-24 15:01:18 +010047 {
48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 }
50
telsoa014fcda012018-03-09 14:13:49 +000051 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
52 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010053
telsoa014fcda012018-03-09 14:13:49 +000054 virtual void Allocate() override
55 {
David Monahan3fb7e102019-08-20 11:25:29 +010056 // If we have enabled Importing, don't Allocate the tensor
57 if (!m_IsImportEnabled)
58 {
59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
60 }
telsoa014fcda012018-03-09 14:13:49 +000061 };
62
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Manage() override
64 {
David Monahan3fb7e102019-08-20 11:25:29 +010065 // If we have enabled Importing, don't manage the tensor
66 if (!m_IsImportEnabled)
67 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010068 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010069 m_MemoryGroup->manage(&m_Tensor);
70 }
telsoa01c577f2c2018-08-31 09:22:23 +010071 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual ITensorHandle* GetParent() const override { return nullptr; }
74
telsoa014fcda012018-03-09 14:13:49 +000075 virtual arm_compute::DataType GetDataType() const override
76 {
77 return m_Tensor.info()->data_type();
78 }
79
telsoa01c577f2c2018-08-31 09:22:23 +010080 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
81 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010082 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010083 }
84
85 virtual const void* Map(bool /* blocking = true */) const override
86 {
87 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
88 }
telsoa01c577f2c2018-08-31 09:22:23 +010089
David Monahan3fb7e102019-08-20 11:25:29 +010090 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010091
92 TensorShape GetStrides() const override
93 {
94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
95 }
96
97 TensorShape GetShape() const override
98 {
99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
100 }
101
David Monahan3fb7e102019-08-20 11:25:29 +0100102 void SetImportFlags(MemorySourceFlags importFlags)
103 {
104 m_ImportFlags = importFlags;
105 }
106
107 MemorySourceFlags GetImportFlags() const override
108 {
109 return m_ImportFlags;
110 }
111
112 void SetImportEnabledFlag(bool importEnabledFlag)
113 {
114 m_IsImportEnabled = importEnabledFlag;
115 }
116
David Monahan0fa10502022-01-13 10:48:33 +0000117 bool CanBeImported(void* memory, MemorySource source) override
118 {
119 armnn::IgnoreUnused(source);
120 if (reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
121 {
122 return false;
123 }
124 return true;
125 }
126
David Monahan3fb7e102019-08-20 11:25:29 +0100127 virtual bool Import(void* memory, MemorySource source) override
128 {
129 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
130 {
131 if (source == MemorySource::Malloc && m_IsImportEnabled)
132 {
David Monahan0fa10502022-01-13 10:48:33 +0000133 if (!CanBeImported(memory, source))
David Monahan3fb7e102019-08-20 11:25:29 +0100134 {
135 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
136 }
137
138 // m_Tensor not yet Allocated
139 if (!m_Imported && !m_Tensor.buffer())
140 {
141 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
142 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
143 // with the Status error message
144 m_Imported = bool(status);
145 if (!m_Imported)
146 {
147 throw MemoryImportException(status.error_description());
148 }
149 return m_Imported;
150 }
151
152 // m_Tensor.buffer() initially allocated with Allocate().
153 if (!m_Imported && m_Tensor.buffer())
154 {
155 throw MemoryImportException(
156 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
157 }
158
159 // m_Tensor.buffer() previously imported.
160 if (m_Imported)
161 {
162 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
163 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
164 // with the Status error message
165 m_Imported = bool(status);
166 if (!m_Imported)
167 {
168 throw MemoryImportException(status.error_description());
169 }
170 return m_Imported;
171 }
172 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100173 else
174 {
175 throw MemoryImportException("NeonTensorHandle::Import is disabled");
176 }
177 }
178 else
179 {
180 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100181 }
182 return false;
183 }
184
telsoa014fcda012018-03-09 14:13:49 +0000185private:
David Beck09e2f272018-10-30 11:38:41 +0000186 // Only used for testing
187 void CopyOutTo(void* memory) const override
188 {
189 switch (this->GetDataType())
190 {
191 case arm_compute::DataType::F32:
192 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
193 static_cast<float*>(memory));
194 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000195 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000196 case arm_compute::DataType::QASYMM8:
197 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
198 static_cast<uint8_t*>(memory));
199 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100200 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100201 case arm_compute::DataType::QASYMM8_SIGNED:
202 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
203 static_cast<int8_t*>(memory));
204 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100205 case arm_compute::DataType::BFLOAT16:
206 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
207 static_cast<armnn::BFloat16*>(memory));
208 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100209 case arm_compute::DataType::F16:
210 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
211 static_cast<armnn::Half*>(memory));
212 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100213 case arm_compute::DataType::S16:
214 case arm_compute::DataType::QSYMM16:
215 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
216 static_cast<int16_t*>(memory));
217 break;
James Conroyd47a0642019-09-17 14:22:06 +0100218 case arm_compute::DataType::S32:
219 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
220 static_cast<int32_t*>(memory));
221 break;
David Beck09e2f272018-10-30 11:38:41 +0000222 default:
223 {
224 throw armnn::UnimplementedException();
225 }
226 }
227 }
228
229 // Only used for testing
230 void CopyInFrom(const void* memory) override
231 {
232 switch (this->GetDataType())
233 {
234 case arm_compute::DataType::F32:
235 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
236 this->GetTensor());
237 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000238 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000239 case arm_compute::DataType::QASYMM8:
240 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
241 this->GetTensor());
242 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100243 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100244 case arm_compute::DataType::QASYMM8_SIGNED:
245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
246 this->GetTensor());
247 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100248 case arm_compute::DataType::BFLOAT16:
249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
250 this->GetTensor());
251 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100252 case arm_compute::DataType::F16:
253 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
254 this->GetTensor());
255 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100256 case arm_compute::DataType::S16:
257 case arm_compute::DataType::QSYMM16:
258 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
259 this->GetTensor());
260 break;
James Conroyd47a0642019-09-17 14:22:06 +0100261 case arm_compute::DataType::S32:
262 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
263 this->GetTensor());
264 break;
David Beck09e2f272018-10-30 11:38:41 +0000265 default:
266 {
267 throw armnn::UnimplementedException();
268 }
269 }
270 }
271
telsoa014fcda012018-03-09 14:13:49 +0000272 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100273 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100274 MemorySourceFlags m_ImportFlags;
275 bool m_Imported;
276 bool m_IsImportEnabled;
Finn Williamsb1aad422021-10-28 19:07:32 +0100277 const uintptr_t m_TypeAlignment;
telsoa014fcda012018-03-09 14:13:49 +0000278};
279
Derek Lambertic81855f2019-06-13 17:34:19 +0100280class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000281{
282public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100283 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100284 const arm_compute::TensorShape& shape,
285 const arm_compute::Coordinates& coords)
286 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000287 {
telsoa01c577f2c2018-08-31 09:22:23 +0100288 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000289 }
290
291 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
292 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100293
294 virtual void Allocate() override {}
295 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000296
telsoa01c577f2c2018-08-31 09:22:23 +0100297 virtual ITensorHandle* GetParent() const override { return parentHandle; }
298
telsoa014fcda012018-03-09 14:13:49 +0000299 virtual arm_compute::DataType GetDataType() const override
300 {
301 return m_Tensor.info()->data_type();
302 }
303
telsoa01c577f2c2018-08-31 09:22:23 +0100304 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
305
306 virtual const void* Map(bool /* blocking = true */) const override
307 {
308 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
309 }
310 virtual void Unmap() const override {}
311
312 TensorShape GetStrides() const override
313 {
314 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
315 }
316
317 TensorShape GetShape() const override
318 {
319 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
320 }
David Beck09e2f272018-10-30 11:38:41 +0000321
telsoa014fcda012018-03-09 14:13:49 +0000322private:
David Beck09e2f272018-10-30 11:38:41 +0000323 // Only used for testing
324 void CopyOutTo(void* memory) const override
325 {
326 switch (this->GetDataType())
327 {
328 case arm_compute::DataType::F32:
329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330 static_cast<float*>(memory));
331 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000332 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000333 case arm_compute::DataType::QASYMM8:
334 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
335 static_cast<uint8_t*>(memory));
336 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100337 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100338 case arm_compute::DataType::QASYMM8_SIGNED:
339 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
340 static_cast<int8_t*>(memory));
341 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100342 case arm_compute::DataType::S16:
343 case arm_compute::DataType::QSYMM16:
344 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
345 static_cast<int16_t*>(memory));
346 break;
James Conroyd47a0642019-09-17 14:22:06 +0100347 case arm_compute::DataType::S32:
348 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
349 static_cast<int32_t*>(memory));
350 break;
David Beck09e2f272018-10-30 11:38:41 +0000351 default:
352 {
353 throw armnn::UnimplementedException();
354 }
355 }
356 }
357
358 // Only used for testing
359 void CopyInFrom(const void* memory) override
360 {
361 switch (this->GetDataType())
362 {
363 case arm_compute::DataType::F32:
364 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
365 this->GetTensor());
366 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000367 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000368 case arm_compute::DataType::QASYMM8:
369 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
370 this->GetTensor());
371 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100372 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100373 case arm_compute::DataType::QASYMM8_SIGNED:
374 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
375 this->GetTensor());
376 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100377 case arm_compute::DataType::S16:
378 case arm_compute::DataType::QSYMM16:
379 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
380 this->GetTensor());
381 break;
James Conroyd47a0642019-09-17 14:22:06 +0100382 case arm_compute::DataType::S32:
383 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
384 this->GetTensor());
385 break;
David Beck09e2f272018-10-30 11:38:41 +0000386 default:
387 {
388 throw armnn::UnimplementedException();
389 }
390 }
391 }
392
telsoa01c577f2c2018-08-31 09:22:23 +0100393 arm_compute::SubTensor m_Tensor;
394 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000395};
396
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100397} // namespace armnn