blob: f251034823bc66ba32c7277f1ec85e06739390c5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
telsoa01c577f2c2018-08-31 09:22:23 +010015#include <arm_compute/runtime/MemoryGroup.h>
16#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000017#include <arm_compute/runtime/Tensor.h>
18#include <arm_compute/runtime/SubTensor.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
telsoa01c577f2c2018-08-31 09:22:23 +010022#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023
24namespace armnn
25{
26
Derek Lambertic81855f2019-06-13 17:34:19 +010027class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000028{
29public:
30 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010031 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
32 m_Imported(false),
33 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
43 m_IsImportEnabled(false)
44
Francis Murtagh351d13d2018-09-24 15:01:18 +010045 {
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 }
48
telsoa014fcda012018-03-09 14:13:49 +000049 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
50 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010051
telsoa014fcda012018-03-09 14:13:49 +000052 virtual void Allocate() override
53 {
David Monahan3fb7e102019-08-20 11:25:29 +010054 // If we have enabled Importing, don't Allocate the tensor
55 if (!m_IsImportEnabled)
56 {
57 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
58 }
telsoa014fcda012018-03-09 14:13:49 +000059 };
60
telsoa01c577f2c2018-08-31 09:22:23 +010061 virtual void Manage() override
62 {
David Monahan3fb7e102019-08-20 11:25:29 +010063 // If we have enabled Importing, don't manage the tensor
64 if (!m_IsImportEnabled)
65 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010066 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010067 m_MemoryGroup->manage(&m_Tensor);
68 }
telsoa01c577f2c2018-08-31 09:22:23 +010069 }
70
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual ITensorHandle* GetParent() const override { return nullptr; }
72
telsoa014fcda012018-03-09 14:13:49 +000073 virtual arm_compute::DataType GetDataType() const override
74 {
75 return m_Tensor.info()->data_type();
76 }
77
telsoa01c577f2c2018-08-31 09:22:23 +010078 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
79 {
80 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
81 }
82
83 virtual const void* Map(bool /* blocking = true */) const override
84 {
85 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
86 }
telsoa01c577f2c2018-08-31 09:22:23 +010087
David Monahan3fb7e102019-08-20 11:25:29 +010088 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010089
90 TensorShape GetStrides() const override
91 {
92 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
93 }
94
95 TensorShape GetShape() const override
96 {
97 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
98 }
99
David Monahan3fb7e102019-08-20 11:25:29 +0100100 void SetImportFlags(MemorySourceFlags importFlags)
101 {
102 m_ImportFlags = importFlags;
103 }
104
105 MemorySourceFlags GetImportFlags() const override
106 {
107 return m_ImportFlags;
108 }
109
110 void SetImportEnabledFlag(bool importEnabledFlag)
111 {
112 m_IsImportEnabled = importEnabledFlag;
113 }
114
115 virtual bool Import(void* memory, MemorySource source) override
116 {
117 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
118 {
119 if (source == MemorySource::Malloc && m_IsImportEnabled)
120 {
121 // Checks the 16 byte memory alignment
122 constexpr uintptr_t alignment = sizeof(size_t);
123 if (reinterpret_cast<uintptr_t>(memory) % alignment)
124 {
125 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
126 }
127
128 // m_Tensor not yet Allocated
129 if (!m_Imported && !m_Tensor.buffer())
130 {
131 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
132 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
133 // with the Status error message
134 m_Imported = bool(status);
135 if (!m_Imported)
136 {
137 throw MemoryImportException(status.error_description());
138 }
139 return m_Imported;
140 }
141
142 // m_Tensor.buffer() initially allocated with Allocate().
143 if (!m_Imported && m_Tensor.buffer())
144 {
145 throw MemoryImportException(
146 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
147 }
148
149 // m_Tensor.buffer() previously imported.
150 if (m_Imported)
151 {
152 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
153 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
154 // with the Status error message
155 m_Imported = bool(status);
156 if (!m_Imported)
157 {
158 throw MemoryImportException(status.error_description());
159 }
160 return m_Imported;
161 }
162 }
163 }
164 return false;
165 }
166
telsoa014fcda012018-03-09 14:13:49 +0000167private:
David Beck09e2f272018-10-30 11:38:41 +0000168 // Only used for testing
169 void CopyOutTo(void* memory) const override
170 {
171 switch (this->GetDataType())
172 {
173 case arm_compute::DataType::F32:
174 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
175 static_cast<float*>(memory));
176 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000177 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000178 case arm_compute::DataType::QASYMM8:
179 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
180 static_cast<uint8_t*>(memory));
181 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100182 case arm_compute::DataType::QASYMM8_SIGNED:
183 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
184 static_cast<int8_t*>(memory));
185 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100186 case arm_compute::DataType::BFLOAT16:
187 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
188 static_cast<armnn::BFloat16*>(memory));
189 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100190 case arm_compute::DataType::F16:
191 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
192 static_cast<armnn::Half*>(memory));
193 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100194 case arm_compute::DataType::S16:
195 case arm_compute::DataType::QSYMM16:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<int16_t*>(memory));
198 break;
James Conroyd47a0642019-09-17 14:22:06 +0100199 case arm_compute::DataType::S32:
200 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
201 static_cast<int32_t*>(memory));
202 break;
David Beck09e2f272018-10-30 11:38:41 +0000203 default:
204 {
205 throw armnn::UnimplementedException();
206 }
207 }
208 }
209
210 // Only used for testing
211 void CopyInFrom(const void* memory) override
212 {
213 switch (this->GetDataType())
214 {
215 case arm_compute::DataType::F32:
216 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
217 this->GetTensor());
218 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000219 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000220 case arm_compute::DataType::QASYMM8:
221 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
222 this->GetTensor());
223 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100224 case arm_compute::DataType::QASYMM8_SIGNED:
225 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
226 this->GetTensor());
227 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100228 case arm_compute::DataType::BFLOAT16:
229 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
230 this->GetTensor());
231 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100232 case arm_compute::DataType::F16:
233 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
234 this->GetTensor());
235 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100236 case arm_compute::DataType::S16:
237 case arm_compute::DataType::QSYMM16:
238 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
239 this->GetTensor());
240 break;
James Conroyd47a0642019-09-17 14:22:06 +0100241 case arm_compute::DataType::S32:
242 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
243 this->GetTensor());
244 break;
David Beck09e2f272018-10-30 11:38:41 +0000245 default:
246 {
247 throw armnn::UnimplementedException();
248 }
249 }
250 }
251
telsoa014fcda012018-03-09 14:13:49 +0000252 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100253 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100254 MemorySourceFlags m_ImportFlags;
255 bool m_Imported;
256 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000257};
258
Derek Lambertic81855f2019-06-13 17:34:19 +0100259class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000260{
261public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100262 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100263 const arm_compute::TensorShape& shape,
264 const arm_compute::Coordinates& coords)
265 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000266 {
telsoa01c577f2c2018-08-31 09:22:23 +0100267 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000268 }
269
270 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
271 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100272
273 virtual void Allocate() override {}
274 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000275
telsoa01c577f2c2018-08-31 09:22:23 +0100276 virtual ITensorHandle* GetParent() const override { return parentHandle; }
277
telsoa014fcda012018-03-09 14:13:49 +0000278 virtual arm_compute::DataType GetDataType() const override
279 {
280 return m_Tensor.info()->data_type();
281 }
282
telsoa01c577f2c2018-08-31 09:22:23 +0100283 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
284
285 virtual const void* Map(bool /* blocking = true */) const override
286 {
287 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
288 }
289 virtual void Unmap() const override {}
290
291 TensorShape GetStrides() const override
292 {
293 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
294 }
295
296 TensorShape GetShape() const override
297 {
298 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
299 }
David Beck09e2f272018-10-30 11:38:41 +0000300
telsoa014fcda012018-03-09 14:13:49 +0000301private:
David Beck09e2f272018-10-30 11:38:41 +0000302 // Only used for testing
303 void CopyOutTo(void* memory) const override
304 {
305 switch (this->GetDataType())
306 {
307 case arm_compute::DataType::F32:
308 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309 static_cast<float*>(memory));
310 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000311 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000312 case arm_compute::DataType::QASYMM8:
313 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
314 static_cast<uint8_t*>(memory));
315 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100316 case arm_compute::DataType::QASYMM8_SIGNED:
317 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
318 static_cast<int8_t*>(memory));
319 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100320 case arm_compute::DataType::S16:
321 case arm_compute::DataType::QSYMM16:
322 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
323 static_cast<int16_t*>(memory));
324 break;
James Conroyd47a0642019-09-17 14:22:06 +0100325 case arm_compute::DataType::S32:
326 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
327 static_cast<int32_t*>(memory));
328 break;
David Beck09e2f272018-10-30 11:38:41 +0000329 default:
330 {
331 throw armnn::UnimplementedException();
332 }
333 }
334 }
335
336 // Only used for testing
337 void CopyInFrom(const void* memory) override
338 {
339 switch (this->GetDataType())
340 {
341 case arm_compute::DataType::F32:
342 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
343 this->GetTensor());
344 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000345 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000346 case arm_compute::DataType::QASYMM8:
347 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
348 this->GetTensor());
349 break;
Sadik Armagane5d0b932020-04-09 15:48:44 +0100350 case arm_compute::DataType::QASYMM8_SIGNED:
351 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
352 this->GetTensor());
353 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100354 case arm_compute::DataType::S16:
355 case arm_compute::DataType::QSYMM16:
356 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
357 this->GetTensor());
358 break;
James Conroyd47a0642019-09-17 14:22:06 +0100359 case arm_compute::DataType::S32:
360 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
361 this->GetTensor());
362 break;
David Beck09e2f272018-10-30 11:38:41 +0000363 default:
364 {
365 throw armnn::UnimplementedException();
366 }
367 }
368 }
369
telsoa01c577f2c2018-08-31 09:22:23 +0100370 arm_compute::SubTensor m_Tensor;
371 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000372};
373
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100374} // namespace armnn