blob: be5bd45956d94312b104f72780495f18ab70c9b3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
32 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000033 {
34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35 }
36
David Monahan3fb7e102019-08-20 11:25:29 +010037 NeonTensorHandle(const TensorInfo& tensorInfo,
38 DataLayout dataLayout,
39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
40 : m_ImportFlags(importFlags),
41 m_Imported(false),
42 m_IsImportEnabled(false)
43
Francis Murtagh351d13d2018-09-24 15:01:18 +010044 {
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46 }
47
telsoa014fcda012018-03-09 14:13:49 +000048 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
49 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010050
telsoa014fcda012018-03-09 14:13:49 +000051 virtual void Allocate() override
52 {
David Monahan3fb7e102019-08-20 11:25:29 +010053 // If we have enabled Importing, don't Allocate the tensor
54 if (!m_IsImportEnabled)
55 {
56 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
57 }
telsoa014fcda012018-03-09 14:13:49 +000058 };
59
telsoa01c577f2c2018-08-31 09:22:23 +010060 virtual void Manage() override
61 {
David Monahan3fb7e102019-08-20 11:25:29 +010062 // If we have enabled Importing, don't manage the tensor
63 if (!m_IsImportEnabled)
64 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010065 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010066 m_MemoryGroup->manage(&m_Tensor);
67 }
telsoa01c577f2c2018-08-31 09:22:23 +010068 }
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 virtual ITensorHandle* GetParent() const override { return nullptr; }
71
telsoa014fcda012018-03-09 14:13:49 +000072 virtual arm_compute::DataType GetDataType() const override
73 {
74 return m_Tensor.info()->data_type();
75 }
76
telsoa01c577f2c2018-08-31 09:22:23 +010077 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
78 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010079 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010080 }
81
82 virtual const void* Map(bool /* blocking = true */) const override
83 {
84 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85 }
telsoa01c577f2c2018-08-31 09:22:23 +010086
David Monahan3fb7e102019-08-20 11:25:29 +010087 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010088
89 TensorShape GetStrides() const override
90 {
91 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
92 }
93
94 TensorShape GetShape() const override
95 {
96 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
97 }
98
David Monahan3fb7e102019-08-20 11:25:29 +010099 void SetImportFlags(MemorySourceFlags importFlags)
100 {
101 m_ImportFlags = importFlags;
102 }
103
104 MemorySourceFlags GetImportFlags() const override
105 {
106 return m_ImportFlags;
107 }
108
109 void SetImportEnabledFlag(bool importEnabledFlag)
110 {
111 m_IsImportEnabled = importEnabledFlag;
112 }
113
114 virtual bool Import(void* memory, MemorySource source) override
115 {
116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117 {
118 if (source == MemorySource::Malloc && m_IsImportEnabled)
119 {
120 // Checks the 16 byte memory alignment
121 constexpr uintptr_t alignment = sizeof(size_t);
122 if (reinterpret_cast<uintptr_t>(memory) % alignment)
123 {
124 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
125 }
126
127 // m_Tensor not yet Allocated
128 if (!m_Imported && !m_Tensor.buffer())
129 {
130 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
131 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
132 // with the Status error message
133 m_Imported = bool(status);
134 if (!m_Imported)
135 {
136 throw MemoryImportException(status.error_description());
137 }
138 return m_Imported;
139 }
140
141 // m_Tensor.buffer() initially allocated with Allocate().
142 if (!m_Imported && m_Tensor.buffer())
143 {
144 throw MemoryImportException(
145 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
146 }
147
148 // m_Tensor.buffer() previously imported.
149 if (m_Imported)
150 {
151 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
152 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
153 // with the Status error message
154 m_Imported = bool(status);
155 if (!m_Imported)
156 {
157 throw MemoryImportException(status.error_description());
158 }
159 return m_Imported;
160 }
161 }
Narumol Prangnawarata2493a02020-08-19 14:39:07 +0100162 else
163 {
164 throw MemoryImportException("NeonTensorHandle::Import is disabled");
165 }
166 }
167 else
168 {
169 throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
David Monahan3fb7e102019-08-20 11:25:29 +0100170 }
171 return false;
172 }
173
telsoa014fcda012018-03-09 14:13:49 +0000174private:
David Beck09e2f272018-10-30 11:38:41 +0000175 // Only used for testing
176 void CopyOutTo(void* memory) const override
177 {
178 switch (this->GetDataType())
179 {
180 case arm_compute::DataType::F32:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<float*>(memory));
183 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000184 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000185 case arm_compute::DataType::QASYMM8:
186 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187 static_cast<uint8_t*>(memory));
188 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100189 case arm_compute::DataType::QASYMM8_SIGNED:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<int8_t*>(memory));
192 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100193 case arm_compute::DataType::BFLOAT16:
194 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195 static_cast<armnn::BFloat16*>(memory));
196 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100197 case arm_compute::DataType::F16:
198 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
199 static_cast<armnn::Half*>(memory));
200 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100201 case arm_compute::DataType::S16:
202 case arm_compute::DataType::QSYMM16:
203 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
204 static_cast<int16_t*>(memory));
205 break;
James Conroyd47a0642019-09-17 14:22:06 +0100206 case arm_compute::DataType::S32:
207 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
208 static_cast<int32_t*>(memory));
209 break;
David Beck09e2f272018-10-30 11:38:41 +0000210 default:
211 {
212 throw armnn::UnimplementedException();
213 }
214 }
215 }
216
217 // Only used for testing
218 void CopyInFrom(const void* memory) override
219 {
220 switch (this->GetDataType())
221 {
222 case arm_compute::DataType::F32:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
224 this->GetTensor());
225 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000226 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000227 case arm_compute::DataType::QASYMM8:
228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
229 this->GetTensor());
230 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100231 case arm_compute::DataType::QASYMM8_SIGNED:
232 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
233 this->GetTensor());
234 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100235 case arm_compute::DataType::BFLOAT16:
236 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
237 this->GetTensor());
238 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100239 case arm_compute::DataType::F16:
240 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
241 this->GetTensor());
242 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100243 case arm_compute::DataType::S16:
244 case arm_compute::DataType::QSYMM16:
245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
246 this->GetTensor());
247 break;
James Conroyd47a0642019-09-17 14:22:06 +0100248 case arm_compute::DataType::S32:
249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
250 this->GetTensor());
251 break;
David Beck09e2f272018-10-30 11:38:41 +0000252 default:
253 {
254 throw armnn::UnimplementedException();
255 }
256 }
257 }
258
telsoa014fcda012018-03-09 14:13:49 +0000259 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100260 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100261 MemorySourceFlags m_ImportFlags;
262 bool m_Imported;
263 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000264};
265
Derek Lambertic81855f2019-06-13 17:34:19 +0100266class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000267{
268public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100269 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100270 const arm_compute::TensorShape& shape,
271 const arm_compute::Coordinates& coords)
272 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000273 {
telsoa01c577f2c2018-08-31 09:22:23 +0100274 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000275 }
276
277 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
278 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100279
280 virtual void Allocate() override {}
281 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000282
telsoa01c577f2c2018-08-31 09:22:23 +0100283 virtual ITensorHandle* GetParent() const override { return parentHandle; }
284
telsoa014fcda012018-03-09 14:13:49 +0000285 virtual arm_compute::DataType GetDataType() const override
286 {
287 return m_Tensor.info()->data_type();
288 }
289
telsoa01c577f2c2018-08-31 09:22:23 +0100290 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
291
292 virtual const void* Map(bool /* blocking = true */) const override
293 {
294 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
295 }
296 virtual void Unmap() const override {}
297
298 TensorShape GetStrides() const override
299 {
300 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
301 }
302
303 TensorShape GetShape() const override
304 {
305 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
306 }
David Beck09e2f272018-10-30 11:38:41 +0000307
telsoa014fcda012018-03-09 14:13:49 +0000308private:
David Beck09e2f272018-10-30 11:38:41 +0000309 // Only used for testing
310 void CopyOutTo(void* memory) const override
311 {
312 switch (this->GetDataType())
313 {
314 case arm_compute::DataType::F32:
315 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
316 static_cast<float*>(memory));
317 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000318 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000319 case arm_compute::DataType::QASYMM8:
320 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
321 static_cast<uint8_t*>(memory));
322 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100323 case arm_compute::DataType::QASYMM8_SIGNED:
324 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
325 static_cast<int8_t*>(memory));
326 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100327 case arm_compute::DataType::S16:
328 case arm_compute::DataType::QSYMM16:
329 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330 static_cast<int16_t*>(memory));
331 break;
James Conroyd47a0642019-09-17 14:22:06 +0100332 case arm_compute::DataType::S32:
333 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
334 static_cast<int32_t*>(memory));
335 break;
David Beck09e2f272018-10-30 11:38:41 +0000336 default:
337 {
338 throw armnn::UnimplementedException();
339 }
340 }
341 }
342
343 // Only used for testing
344 void CopyInFrom(const void* memory) override
345 {
346 switch (this->GetDataType())
347 {
348 case arm_compute::DataType::F32:
349 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
350 this->GetTensor());
351 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000352 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000353 case arm_compute::DataType::QASYMM8:
354 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
355 this->GetTensor());
356 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100357 case arm_compute::DataType::QASYMM8_SIGNED:
358 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
359 this->GetTensor());
360 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100361 case arm_compute::DataType::S16:
362 case arm_compute::DataType::QSYMM16:
363 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
364 this->GetTensor());
365 break;
James Conroyd47a0642019-09-17 14:22:06 +0100366 case arm_compute::DataType::S32:
367 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
368 this->GetTensor());
369 break;
David Beck09e2f272018-10-30 11:38:41 +0000370 default:
371 {
372 throw armnn::UnimplementedException();
373 }
374 }
375 }
376
telsoa01c577f2c2018-08-31 09:22:23 +0100377 arm_compute::SubTensor m_Tensor;
378 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000379};
380
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100381} // namespace armnn