blob: dd4c2572f99782bea6aa33ed1915c4fc2a36956f [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010032 m_IsImportEnabled(false),
33 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010043 m_IsImportEnabled(false),
44 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
45
David Monahan3fb7e102019-08-20 11:25:29 +010046
Francis Murtagh351d13d2018-09-24 15:01:18 +010047 {
48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 }
50
telsoa014fcda012018-03-09 14:13:49 +000051 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
52 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010053
telsoa014fcda012018-03-09 14:13:49 +000054 virtual void Allocate() override
55 {
David Monahan3fb7e102019-08-20 11:25:29 +010056 // If we have enabled Importing, don't Allocate the tensor
57 if (!m_IsImportEnabled)
58 {
59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
60 }
telsoa014fcda012018-03-09 14:13:49 +000061 };
62
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Manage() override
64 {
David Monahan3fb7e102019-08-20 11:25:29 +010065 // If we have enabled Importing, don't manage the tensor
66 if (!m_IsImportEnabled)
67 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010068 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010069 m_MemoryGroup->manage(&m_Tensor);
70 }
telsoa01c577f2c2018-08-31 09:22:23 +010071 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual ITensorHandle* GetParent() const override { return nullptr; }
74
telsoa014fcda012018-03-09 14:13:49 +000075 virtual arm_compute::DataType GetDataType() const override
76 {
77 return m_Tensor.info()->data_type();
78 }
79
telsoa01c577f2c2018-08-31 09:22:23 +010080 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
81 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010082 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010083 }
84
85 virtual const void* Map(bool /* blocking = true */) const override
86 {
87 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
88 }
telsoa01c577f2c2018-08-31 09:22:23 +010089
David Monahan3fb7e102019-08-20 11:25:29 +010090 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010091
92 TensorShape GetStrides() const override
93 {
94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
95 }
96
97 TensorShape GetShape() const override
98 {
99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
100 }
101
David Monahan3fb7e102019-08-20 11:25:29 +0100102 void SetImportFlags(MemorySourceFlags importFlags)
103 {
104 m_ImportFlags = importFlags;
105 }
106
107 MemorySourceFlags GetImportFlags() const override
108 {
109 return m_ImportFlags;
110 }
111
112 void SetImportEnabledFlag(bool importEnabledFlag)
113 {
114 m_IsImportEnabled = importEnabledFlag;
115 }
116
117 virtual bool Import(void* memory, MemorySource source) override
118 {
119 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
120 {
121 if (source == MemorySource::Malloc && m_IsImportEnabled)
122 {
Finn Williamsb1aad422021-10-28 19:07:32 +0100123 if (reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
David Monahan3fb7e102019-08-20 11:25:29 +0100124 {
125 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
126 }
127
128 // m_Tensor not yet Allocated
129 if (!m_Imported && !m_Tensor.buffer())
130 {
131 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
132 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
133 // with the Status error message
134 m_Imported = bool(status);
135 if (!m_Imported)
136 {
137 throw MemoryImportException(status.error_description());
138 }
139 return m_Imported;
140 }
141
142 // m_Tensor.buffer() initially allocated with Allocate().
143 if (!m_Imported && m_Tensor.buffer())
144 {
145 throw MemoryImportException(
146 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
147 }
148
149 // m_Tensor.buffer() previously imported.
150 if (m_Imported)
151 {
152 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
153 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
154 // with the Status error message
155 m_Imported = bool(status);
156 if (!m_Imported)
157 {
158 throw MemoryImportException(status.error_description());
159 }
160 return m_Imported;
161 }
162 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100163 else
164 {
165 throw MemoryImportException("NeonTensorHandle::Import is disabled");
166 }
167 }
168 else
169 {
170 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100171 }
172 return false;
173 }
174
telsoa014fcda012018-03-09 14:13:49 +0000175private:
David Beck09e2f272018-10-30 11:38:41 +0000176 // Only used for testing
177 void CopyOutTo(void* memory) const override
178 {
179 switch (this->GetDataType())
180 {
181 case arm_compute::DataType::F32:
182 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
183 static_cast<float*>(memory));
184 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000185 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000186 case arm_compute::DataType::QASYMM8:
187 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
188 static_cast<uint8_t*>(memory));
189 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100190 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100191 case arm_compute::DataType::QASYMM8_SIGNED:
192 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
193 static_cast<int8_t*>(memory));
194 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100195 case arm_compute::DataType::BFLOAT16:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<armnn::BFloat16*>(memory));
198 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100199 case arm_compute::DataType::F16:
200 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
201 static_cast<armnn::Half*>(memory));
202 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100203 case arm_compute::DataType::S16:
204 case arm_compute::DataType::QSYMM16:
205 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
206 static_cast<int16_t*>(memory));
207 break;
James Conroyd47a0642019-09-17 14:22:06 +0100208 case arm_compute::DataType::S32:
209 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
210 static_cast<int32_t*>(memory));
211 break;
David Beck09e2f272018-10-30 11:38:41 +0000212 default:
213 {
214 throw armnn::UnimplementedException();
215 }
216 }
217 }
218
219 // Only used for testing
220 void CopyInFrom(const void* memory) override
221 {
222 switch (this->GetDataType())
223 {
224 case arm_compute::DataType::F32:
225 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
226 this->GetTensor());
227 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000228 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000229 case arm_compute::DataType::QASYMM8:
230 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
231 this->GetTensor());
232 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100233 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100234 case arm_compute::DataType::QASYMM8_SIGNED:
235 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
236 this->GetTensor());
237 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100238 case arm_compute::DataType::BFLOAT16:
239 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
240 this->GetTensor());
241 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100242 case arm_compute::DataType::F16:
243 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
244 this->GetTensor());
245 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100246 case arm_compute::DataType::S16:
247 case arm_compute::DataType::QSYMM16:
248 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
249 this->GetTensor());
250 break;
James Conroyd47a0642019-09-17 14:22:06 +0100251 case arm_compute::DataType::S32:
252 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
253 this->GetTensor());
254 break;
David Beck09e2f272018-10-30 11:38:41 +0000255 default:
256 {
257 throw armnn::UnimplementedException();
258 }
259 }
260 }
261
telsoa014fcda012018-03-09 14:13:49 +0000262 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100263 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100264 MemorySourceFlags m_ImportFlags;
265 bool m_Imported;
266 bool m_IsImportEnabled;
Finn Williamsb1aad422021-10-28 19:07:32 +0100267 const uintptr_t m_TypeAlignment;
telsoa014fcda012018-03-09 14:13:49 +0000268};
269
Derek Lambertic81855f2019-06-13 17:34:19 +0100270class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000271{
272public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100273 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100274 const arm_compute::TensorShape& shape,
275 const arm_compute::Coordinates& coords)
276 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000277 {
telsoa01c577f2c2018-08-31 09:22:23 +0100278 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000279 }
280
281 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
282 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100283
284 virtual void Allocate() override {}
285 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000286
telsoa01c577f2c2018-08-31 09:22:23 +0100287 virtual ITensorHandle* GetParent() const override { return parentHandle; }
288
telsoa014fcda012018-03-09 14:13:49 +0000289 virtual arm_compute::DataType GetDataType() const override
290 {
291 return m_Tensor.info()->data_type();
292 }
293
telsoa01c577f2c2018-08-31 09:22:23 +0100294 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
295
296 virtual const void* Map(bool /* blocking = true */) const override
297 {
298 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
299 }
300 virtual void Unmap() const override {}
301
302 TensorShape GetStrides() const override
303 {
304 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
305 }
306
307 TensorShape GetShape() const override
308 {
309 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
310 }
David Beck09e2f272018-10-30 11:38:41 +0000311
telsoa014fcda012018-03-09 14:13:49 +0000312private:
David Beck09e2f272018-10-30 11:38:41 +0000313 // Only used for testing
314 void CopyOutTo(void* memory) const override
315 {
316 switch (this->GetDataType())
317 {
318 case arm_compute::DataType::F32:
319 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
320 static_cast<float*>(memory));
321 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000322 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000323 case arm_compute::DataType::QASYMM8:
324 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
325 static_cast<uint8_t*>(memory));
326 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100327 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100328 case arm_compute::DataType::QASYMM8_SIGNED:
329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330 static_cast<int8_t*>(memory));
331 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100332 case arm_compute::DataType::S16:
333 case arm_compute::DataType::QSYMM16:
334 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
335 static_cast<int16_t*>(memory));
336 break;
James Conroyd47a0642019-09-17 14:22:06 +0100337 case arm_compute::DataType::S32:
338 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
339 static_cast<int32_t*>(memory));
340 break;
David Beck09e2f272018-10-30 11:38:41 +0000341 default:
342 {
343 throw armnn::UnimplementedException();
344 }
345 }
346 }
347
348 // Only used for testing
349 void CopyInFrom(const void* memory) override
350 {
351 switch (this->GetDataType())
352 {
353 case arm_compute::DataType::F32:
354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
355 this->GetTensor());
356 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000357 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000358 case arm_compute::DataType::QASYMM8:
359 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
360 this->GetTensor());
361 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100362 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100363 case arm_compute::DataType::QASYMM8_SIGNED:
364 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
365 this->GetTensor());
366 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100367 case arm_compute::DataType::S16:
368 case arm_compute::DataType::QSYMM16:
369 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
370 this->GetTensor());
371 break;
James Conroyd47a0642019-09-17 14:22:06 +0100372 case arm_compute::DataType::S32:
373 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
374 this->GetTensor());
375 break;
David Beck09e2f272018-10-30 11:38:41 +0000376 default:
377 {
378 throw armnn::UnimplementedException();
379 }
380 }
381 }
382
telsoa01c577f2c2018-08-31 09:22:23 +0100383 arm_compute::SubTensor m_Tensor;
384 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000385};
386
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100387} // namespace armnn