blob: fcae77cdaacce50f3ba349b8646a2bd8da9960bd [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010032 m_IsImportEnabled(false),
33 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
Finn Williamsb1aad422021-10-28 19:07:32 +010043 m_IsImportEnabled(false),
44 m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
45
David Monahan3fb7e102019-08-20 11:25:29 +010046
Francis Murtagh351d13d2018-09-24 15:01:18 +010047 {
48 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49 }
50
telsoa014fcda012018-03-09 14:13:49 +000051 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
52 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010053
telsoa014fcda012018-03-09 14:13:49 +000054 virtual void Allocate() override
55 {
David Monahan3fb7e102019-08-20 11:25:29 +010056 // If we have enabled Importing, don't Allocate the tensor
57 if (!m_IsImportEnabled)
58 {
59 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
60 }
telsoa014fcda012018-03-09 14:13:49 +000061 };
62
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Manage() override
64 {
David Monahan3fb7e102019-08-20 11:25:29 +010065 // If we have enabled Importing, don't manage the tensor
66 if (!m_IsImportEnabled)
67 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010068 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010069 m_MemoryGroup->manage(&m_Tensor);
70 }
telsoa01c577f2c2018-08-31 09:22:23 +010071 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual ITensorHandle* GetParent() const override { return nullptr; }
74
telsoa014fcda012018-03-09 14:13:49 +000075 virtual arm_compute::DataType GetDataType() const override
76 {
77 return m_Tensor.info()->data_type();
78 }
79
telsoa01c577f2c2018-08-31 09:22:23 +010080 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
81 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010082 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010083 }
84
85 virtual const void* Map(bool /* blocking = true */) const override
86 {
87 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
88 }
telsoa01c577f2c2018-08-31 09:22:23 +010089
David Monahan3fb7e102019-08-20 11:25:29 +010090 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010091
92 TensorShape GetStrides() const override
93 {
94 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
95 }
96
97 TensorShape GetShape() const override
98 {
99 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
100 }
101
David Monahan3fb7e102019-08-20 11:25:29 +0100102 void SetImportFlags(MemorySourceFlags importFlags)
103 {
104 m_ImportFlags = importFlags;
105 }
106
107 MemorySourceFlags GetImportFlags() const override
108 {
109 return m_ImportFlags;
110 }
111
112 void SetImportEnabledFlag(bool importEnabledFlag)
113 {
114 m_IsImportEnabled = importEnabledFlag;
115 }
116
David Monahan0fa10502022-01-13 10:48:33 +0000117 bool CanBeImported(void* memory, MemorySource source) override
118 {
David Monahan3826ab62022-02-21 12:26:16 +0000119 if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
David Monahan0fa10502022-01-13 10:48:33 +0000120 {
121 return false;
122 }
123 return true;
124 }
125
David Monahan3fb7e102019-08-20 11:25:29 +0100126 virtual bool Import(void* memory, MemorySource source) override
127 {
128 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
129 {
130 if (source == MemorySource::Malloc && m_IsImportEnabled)
131 {
David Monahan0fa10502022-01-13 10:48:33 +0000132 if (!CanBeImported(memory, source))
David Monahan3fb7e102019-08-20 11:25:29 +0100133 {
134 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
135 }
136
137 // m_Tensor not yet Allocated
138 if (!m_Imported && !m_Tensor.buffer())
139 {
140 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
141 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
142 // with the Status error message
143 m_Imported = bool(status);
144 if (!m_Imported)
145 {
146 throw MemoryImportException(status.error_description());
147 }
148 return m_Imported;
149 }
150
151 // m_Tensor.buffer() initially allocated with Allocate().
152 if (!m_Imported && m_Tensor.buffer())
153 {
154 throw MemoryImportException(
155 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
156 }
157
158 // m_Tensor.buffer() previously imported.
159 if (m_Imported)
160 {
161 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
162 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
163 // with the Status error message
164 m_Imported = bool(status);
165 if (!m_Imported)
166 {
167 throw MemoryImportException(status.error_description());
168 }
169 return m_Imported;
170 }
171 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100172 else
173 {
174 throw MemoryImportException("NeonTensorHandle::Import is disabled");
175 }
176 }
177 else
178 {
179 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100180 }
181 return false;
182 }
183
telsoa014fcda012018-03-09 14:13:49 +0000184private:
David Beck09e2f272018-10-30 11:38:41 +0000185 // Only used for testing
186 void CopyOutTo(void* memory) const override
187 {
188 switch (this->GetDataType())
189 {
190 case arm_compute::DataType::F32:
191 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
192 static_cast<float*>(memory));
193 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000194 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000195 case arm_compute::DataType::QASYMM8:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<uint8_t*>(memory));
198 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100199 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100200 case arm_compute::DataType::QASYMM8_SIGNED:
201 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
202 static_cast<int8_t*>(memory));
203 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100204 case arm_compute::DataType::BFLOAT16:
205 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
206 static_cast<armnn::BFloat16*>(memory));
207 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100208 case arm_compute::DataType::F16:
209 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
210 static_cast<armnn::Half*>(memory));
211 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100212 case arm_compute::DataType::S16:
213 case arm_compute::DataType::QSYMM16:
214 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
215 static_cast<int16_t*>(memory));
216 break;
James Conroyd47a0642019-09-17 14:22:06 +0100217 case arm_compute::DataType::S32:
218 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
219 static_cast<int32_t*>(memory));
220 break;
David Beck09e2f272018-10-30 11:38:41 +0000221 default:
222 {
223 throw armnn::UnimplementedException();
224 }
225 }
226 }
227
228 // Only used for testing
229 void CopyInFrom(const void* memory) override
230 {
231 switch (this->GetDataType())
232 {
233 case arm_compute::DataType::F32:
234 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
235 this->GetTensor());
236 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000237 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000238 case arm_compute::DataType::QASYMM8:
239 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
240 this->GetTensor());
241 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100242 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100243 case arm_compute::DataType::QASYMM8_SIGNED:
Cathal Corbett06902652022-04-14 17:55:11 +0100244 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
246 this->GetTensor());
247 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100248 case arm_compute::DataType::BFLOAT16:
249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
250 this->GetTensor());
251 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100252 case arm_compute::DataType::F16:
253 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
254 this->GetTensor());
255 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100256 case arm_compute::DataType::S16:
257 case arm_compute::DataType::QSYMM16:
258 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
259 this->GetTensor());
260 break;
James Conroyd47a0642019-09-17 14:22:06 +0100261 case arm_compute::DataType::S32:
262 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
263 this->GetTensor());
264 break;
David Beck09e2f272018-10-30 11:38:41 +0000265 default:
266 {
267 throw armnn::UnimplementedException();
268 }
269 }
270 }
271
telsoa014fcda012018-03-09 14:13:49 +0000272 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100273 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100274 MemorySourceFlags m_ImportFlags;
275 bool m_Imported;
276 bool m_IsImportEnabled;
Finn Williamsb1aad422021-10-28 19:07:32 +0100277 const uintptr_t m_TypeAlignment;
telsoa014fcda012018-03-09 14:13:49 +0000278};
279
Derek Lambertic81855f2019-06-13 17:34:19 +0100280class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000281{
282public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100283 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100284 const arm_compute::TensorShape& shape,
285 const arm_compute::Coordinates& coords)
286 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000287 {
telsoa01c577f2c2018-08-31 09:22:23 +0100288 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000289 }
290
291 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
292 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100293
294 virtual void Allocate() override {}
295 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000296
telsoa01c577f2c2018-08-31 09:22:23 +0100297 virtual ITensorHandle* GetParent() const override { return parentHandle; }
298
telsoa014fcda012018-03-09 14:13:49 +0000299 virtual arm_compute::DataType GetDataType() const override
300 {
301 return m_Tensor.info()->data_type();
302 }
303
telsoa01c577f2c2018-08-31 09:22:23 +0100304 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
305
306 virtual const void* Map(bool /* blocking = true */) const override
307 {
308 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
309 }
310 virtual void Unmap() const override {}
311
312 TensorShape GetStrides() const override
313 {
314 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
315 }
316
317 TensorShape GetShape() const override
318 {
319 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
320 }
David Beck09e2f272018-10-30 11:38:41 +0000321
telsoa014fcda012018-03-09 14:13:49 +0000322private:
David Beck09e2f272018-10-30 11:38:41 +0000323 // Only used for testing
324 void CopyOutTo(void* memory) const override
325 {
326 switch (this->GetDataType())
327 {
328 case arm_compute::DataType::F32:
329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330 static_cast<float*>(memory));
331 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000332 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000333 case arm_compute::DataType::QASYMM8:
334 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
335 static_cast<uint8_t*>(memory));
336 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100337 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100338 case arm_compute::DataType::QASYMM8_SIGNED:
339 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
340 static_cast<int8_t*>(memory));
341 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100342 case arm_compute::DataType::S16:
343 case arm_compute::DataType::QSYMM16:
344 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
345 static_cast<int16_t*>(memory));
346 break;
James Conroyd47a0642019-09-17 14:22:06 +0100347 case arm_compute::DataType::S32:
348 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
349 static_cast<int32_t*>(memory));
350 break;
David Beck09e2f272018-10-30 11:38:41 +0000351 default:
352 {
353 throw armnn::UnimplementedException();
354 }
355 }
356 }
357
358 // Only used for testing
359 void CopyInFrom(const void* memory) override
360 {
361 switch (this->GetDataType())
362 {
363 case arm_compute::DataType::F32:
364 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
365 this->GetTensor());
366 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000367 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000368 case arm_compute::DataType::QASYMM8:
369 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
370 this->GetTensor());
371 break;
Sadik Armagan48f011e2021-04-21 10:50:34 +0100372 case arm_compute::DataType::QSYMM8:
Sadik Armagane5d0b932020-04-09 15:48:44 +0100373 case arm_compute::DataType::QASYMM8_SIGNED:
374 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
375 this->GetTensor());
376 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100377 case arm_compute::DataType::S16:
378 case arm_compute::DataType::QSYMM16:
379 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
380 this->GetTensor());
381 break;
James Conroyd47a0642019-09-17 14:22:06 +0100382 case arm_compute::DataType::S32:
383 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
384 this->GetTensor());
385 break;
David Beck09e2f272018-10-30 11:38:41 +0000386 default:
387 {
388 throw armnn::UnimplementedException();
389 }
390 }
391 }
392
telsoa01c577f2c2018-08-31 09:22:23 +0100393 arm_compute::SubTensor m_Tensor;
394 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000395};
396
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100397} // namespace armnn