blob: 4cc610c85a4e617d249bb97a0b87e7045c610afb [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
Jan Eilers3c9e0452020-04-10 13:00:44 +010014#include <armnn/utility/PolymorphicDowncast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000015
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/MemoryGroup.h>
17#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/runtime/Tensor.h>
19#include <arm_compute/runtime/SubTensor.h>
20#include <arm_compute/core/TensorShape.h>
21#include <arm_compute/core/Coordinates.h>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31 m_Imported(false),
32 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000033 {
34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35 }
36
David Monahan3fb7e102019-08-20 11:25:29 +010037 NeonTensorHandle(const TensorInfo& tensorInfo,
38 DataLayout dataLayout,
39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
40 : m_ImportFlags(importFlags),
41 m_Imported(false),
42 m_IsImportEnabled(false)
43
Francis Murtagh351d13d2018-09-24 15:01:18 +010044 {
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46 }
47
telsoa014fcda012018-03-09 14:13:49 +000048 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
49 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010050
telsoa014fcda012018-03-09 14:13:49 +000051 virtual void Allocate() override
52 {
David Monahan3fb7e102019-08-20 11:25:29 +010053 // If we have enabled Importing, don't Allocate the tensor
54 if (!m_IsImportEnabled)
55 {
56 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
57 }
telsoa014fcda012018-03-09 14:13:49 +000058 };
59
telsoa01c577f2c2018-08-31 09:22:23 +010060 virtual void Manage() override
61 {
David Monahan3fb7e102019-08-20 11:25:29 +010062 // If we have enabled Importing, don't manage the tensor
63 if (!m_IsImportEnabled)
64 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010065 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010066 m_MemoryGroup->manage(&m_Tensor);
67 }
telsoa01c577f2c2018-08-31 09:22:23 +010068 }
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 virtual ITensorHandle* GetParent() const override { return nullptr; }
71
telsoa014fcda012018-03-09 14:13:49 +000072 virtual arm_compute::DataType GetDataType() const override
73 {
74 return m_Tensor.info()->data_type();
75 }
76
telsoa01c577f2c2018-08-31 09:22:23 +010077 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
78 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010079 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010080 }
81
82 virtual const void* Map(bool /* blocking = true */) const override
83 {
84 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
85 }
telsoa01c577f2c2018-08-31 09:22:23 +010086
David Monahan3fb7e102019-08-20 11:25:29 +010087 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010088
89 TensorShape GetStrides() const override
90 {
91 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
92 }
93
94 TensorShape GetShape() const override
95 {
96 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
97 }
98
David Monahan3fb7e102019-08-20 11:25:29 +010099 void SetImportFlags(MemorySourceFlags importFlags)
100 {
101 m_ImportFlags = importFlags;
102 }
103
104 MemorySourceFlags GetImportFlags() const override
105 {
106 return m_ImportFlags;
107 }
108
109 void SetImportEnabledFlag(bool importEnabledFlag)
110 {
111 m_IsImportEnabled = importEnabledFlag;
112 }
113
114 virtual bool Import(void* memory, MemorySource source) override
115 {
116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117 {
118 if (source == MemorySource::Malloc && m_IsImportEnabled)
119 {
120 // Checks the 16 byte memory alignment
121 constexpr uintptr_t alignment = sizeof(size_t);
122 if (reinterpret_cast<uintptr_t>(memory) % alignment)
123 {
124 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
125 }
126
127 // m_Tensor not yet Allocated
128 if (!m_Imported && !m_Tensor.buffer())
129 {
130 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
131 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
132 // with the Status error message
133 m_Imported = bool(status);
134 if (!m_Imported)
135 {
136 throw MemoryImportException(status.error_description());
137 }
138 return m_Imported;
139 }
140
141 // m_Tensor.buffer() initially allocated with Allocate().
142 if (!m_Imported && m_Tensor.buffer())
143 {
144 throw MemoryImportException(
145 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
146 }
147
148 // m_Tensor.buffer() previously imported.
149 if (m_Imported)
150 {
151 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
152 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
153 // with the Status error message
154 m_Imported = bool(status);
155 if (!m_Imported)
156 {
157 throw MemoryImportException(status.error_description());
158 }
159 return m_Imported;
160 }
161 }
162 }
163 return false;
164 }
165
telsoa014fcda012018-03-09 14:13:49 +0000166private:
David Beck09e2f272018-10-30 11:38:41 +0000167 // Only used for testing
168 void CopyOutTo(void* memory) const override
169 {
170 switch (this->GetDataType())
171 {
172 case arm_compute::DataType::F32:
173 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
174 static_cast<float*>(memory));
175 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000176 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000177 case arm_compute::DataType::QASYMM8:
178 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
179 static_cast<uint8_t*>(memory));
180 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100181 case arm_compute::DataType::QASYMM8_SIGNED:
182 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
183 static_cast<int8_t*>(memory));
184 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100185 case arm_compute::DataType::BFLOAT16:
186 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187 static_cast<armnn::BFloat16*>(memory));
188 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100189 case arm_compute::DataType::F16:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<armnn::Half*>(memory));
192 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100193 case arm_compute::DataType::S16:
194 case arm_compute::DataType::QSYMM16:
195 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
196 static_cast<int16_t*>(memory));
197 break;
James Conroyd47a0642019-09-17 14:22:06 +0100198 case arm_compute::DataType::S32:
199 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
200 static_cast<int32_t*>(memory));
201 break;
David Beck09e2f272018-10-30 11:38:41 +0000202 default:
203 {
204 throw armnn::UnimplementedException();
205 }
206 }
207 }
208
209 // Only used for testing
210 void CopyInFrom(const void* memory) override
211 {
212 switch (this->GetDataType())
213 {
214 case arm_compute::DataType::F32:
215 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
216 this->GetTensor());
217 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000218 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000219 case arm_compute::DataType::QASYMM8:
220 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
221 this->GetTensor());
222 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100223 case arm_compute::DataType::QASYMM8_SIGNED:
224 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
225 this->GetTensor());
226 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100227 case arm_compute::DataType::BFLOAT16:
228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
229 this->GetTensor());
230 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100231 case arm_compute::DataType::F16:
232 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
233 this->GetTensor());
234 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100235 case arm_compute::DataType::S16:
236 case arm_compute::DataType::QSYMM16:
237 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
238 this->GetTensor());
239 break;
James Conroyd47a0642019-09-17 14:22:06 +0100240 case arm_compute::DataType::S32:
241 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
242 this->GetTensor());
243 break;
David Beck09e2f272018-10-30 11:38:41 +0000244 default:
245 {
246 throw armnn::UnimplementedException();
247 }
248 }
249 }
250
telsoa014fcda012018-03-09 14:13:49 +0000251 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100252 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100253 MemorySourceFlags m_ImportFlags;
254 bool m_Imported;
255 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000256};
257
Derek Lambertic81855f2019-06-13 17:34:19 +0100258class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000259{
260public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100261 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100262 const arm_compute::TensorShape& shape,
263 const arm_compute::Coordinates& coords)
264 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000265 {
telsoa01c577f2c2018-08-31 09:22:23 +0100266 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000267 }
268
269 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
270 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100271
272 virtual void Allocate() override {}
273 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000274
telsoa01c577f2c2018-08-31 09:22:23 +0100275 virtual ITensorHandle* GetParent() const override { return parentHandle; }
276
telsoa014fcda012018-03-09 14:13:49 +0000277 virtual arm_compute::DataType GetDataType() const override
278 {
279 return m_Tensor.info()->data_type();
280 }
281
telsoa01c577f2c2018-08-31 09:22:23 +0100282 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
283
284 virtual const void* Map(bool /* blocking = true */) const override
285 {
286 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
287 }
288 virtual void Unmap() const override {}
289
290 TensorShape GetStrides() const override
291 {
292 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
293 }
294
295 TensorShape GetShape() const override
296 {
297 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
298 }
David Beck09e2f272018-10-30 11:38:41 +0000299
telsoa014fcda012018-03-09 14:13:49 +0000300private:
David Beck09e2f272018-10-30 11:38:41 +0000301 // Only used for testing
302 void CopyOutTo(void* memory) const override
303 {
304 switch (this->GetDataType())
305 {
306 case arm_compute::DataType::F32:
307 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
308 static_cast<float*>(memory));
309 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000310 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000311 case arm_compute::DataType::QASYMM8:
312 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
313 static_cast<uint8_t*>(memory));
314 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100315 case arm_compute::DataType::QASYMM8_SIGNED:
316 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
317 static_cast<int8_t*>(memory));
318 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100319 case arm_compute::DataType::S16:
320 case arm_compute::DataType::QSYMM16:
321 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
322 static_cast<int16_t*>(memory));
323 break;
James Conroyd47a0642019-09-17 14:22:06 +0100324 case arm_compute::DataType::S32:
325 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
326 static_cast<int32_t*>(memory));
327 break;
David Beck09e2f272018-10-30 11:38:41 +0000328 default:
329 {
330 throw armnn::UnimplementedException();
331 }
332 }
333 }
334
335 // Only used for testing
336 void CopyInFrom(const void* memory) override
337 {
338 switch (this->GetDataType())
339 {
340 case arm_compute::DataType::F32:
341 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
342 this->GetTensor());
343 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000344 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000345 case arm_compute::DataType::QASYMM8:
346 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
347 this->GetTensor());
348 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100349 case arm_compute::DataType::QASYMM8_SIGNED:
350 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
351 this->GetTensor());
352 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100353 case arm_compute::DataType::S16:
354 case arm_compute::DataType::QSYMM16:
355 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
356 this->GetTensor());
357 break;
James Conroyd47a0642019-09-17 14:22:06 +0100358 case arm_compute::DataType::S32:
359 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
360 this->GetTensor());
361 break;
David Beck09e2f272018-10-30 11:38:41 +0000362 default:
363 {
364 throw armnn::UnimplementedException();
365 }
366 }
367 }
368
telsoa01c577f2c2018-08-31 09:22:23 +0100369 arm_compute::SubTensor m_Tensor;
370 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000371};
372
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100373} // namespace armnn