blob: 11d20878d76e447b0d1896be10be766ee3c214f5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Derek Lambertic81855f2019-06-13 17:34:19 +010010#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
telsoa01c577f2c2018-08-31 09:22:23 +010013#include <arm_compute/runtime/MemoryGroup.h>
14#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000015#include <arm_compute/runtime/Tensor.h>
16#include <arm_compute/runtime/SubTensor.h>
17#include <arm_compute/core/TensorShape.h>
18#include <arm_compute/core/Coordinates.h>
19
telsoa01c577f2c2018-08-31 09:22:23 +010020#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22namespace armnn
23{
24
Derek Lambertic81855f2019-06-13 17:34:19 +010025class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000026{
27public:
28 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010029 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
30 m_Imported(false),
31 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000032 {
33 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
34 }
35
David Monahan3fb7e102019-08-20 11:25:29 +010036 NeonTensorHandle(const TensorInfo& tensorInfo,
37 DataLayout dataLayout,
38 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
39 : m_ImportFlags(importFlags),
40 m_Imported(false),
41 m_IsImportEnabled(false)
42
Francis Murtagh351d13d2018-09-24 15:01:18 +010043 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
48 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010049
telsoa014fcda012018-03-09 14:13:49 +000050 virtual void Allocate() override
51 {
David Monahan3fb7e102019-08-20 11:25:29 +010052 // If we have enabled Importing, don't Allocate the tensor
53 if (!m_IsImportEnabled)
54 {
55 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
56 }
telsoa014fcda012018-03-09 14:13:49 +000057 };
58
telsoa01c577f2c2018-08-31 09:22:23 +010059 virtual void Manage() override
60 {
David Monahan3fb7e102019-08-20 11:25:29 +010061 // If we have enabled Importing, don't manage the tensor
62 if (!m_IsImportEnabled)
63 {
64 BOOST_ASSERT(m_MemoryGroup != nullptr);
65 m_MemoryGroup->manage(&m_Tensor);
66 }
telsoa01c577f2c2018-08-31 09:22:23 +010067 }
68
telsoa01c577f2c2018-08-31 09:22:23 +010069 virtual ITensorHandle* GetParent() const override { return nullptr; }
70
telsoa014fcda012018-03-09 14:13:49 +000071 virtual arm_compute::DataType GetDataType() const override
72 {
73 return m_Tensor.info()->data_type();
74 }
75
telsoa01c577f2c2018-08-31 09:22:23 +010076 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
77 {
78 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
79 }
80
81 virtual const void* Map(bool /* blocking = true */) const override
82 {
83 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
84 }
telsoa01c577f2c2018-08-31 09:22:23 +010085
David Monahan3fb7e102019-08-20 11:25:29 +010086 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010087
88 TensorShape GetStrides() const override
89 {
90 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
91 }
92
93 TensorShape GetShape() const override
94 {
95 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
96 }
97
David Monahan3fb7e102019-08-20 11:25:29 +010098 void SetImportFlags(MemorySourceFlags importFlags)
99 {
100 m_ImportFlags = importFlags;
101 }
102
103 MemorySourceFlags GetImportFlags() const override
104 {
105 return m_ImportFlags;
106 }
107
108 void SetImportEnabledFlag(bool importEnabledFlag)
109 {
110 m_IsImportEnabled = importEnabledFlag;
111 }
112
113 virtual bool Import(void* memory, MemorySource source) override
114 {
115 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
116 {
117 if (source == MemorySource::Malloc && m_IsImportEnabled)
118 {
119 // Checks the 16 byte memory alignment
120 constexpr uintptr_t alignment = sizeof(size_t);
121 if (reinterpret_cast<uintptr_t>(memory) % alignment)
122 {
123 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
124 }
125
126 // m_Tensor not yet Allocated
127 if (!m_Imported && !m_Tensor.buffer())
128 {
129 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
130 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
131 // with the Status error message
132 m_Imported = bool(status);
133 if (!m_Imported)
134 {
135 throw MemoryImportException(status.error_description());
136 }
137 return m_Imported;
138 }
139
140 // m_Tensor.buffer() initially allocated with Allocate().
141 if (!m_Imported && m_Tensor.buffer())
142 {
143 throw MemoryImportException(
144 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
145 }
146
147 // m_Tensor.buffer() previously imported.
148 if (m_Imported)
149 {
150 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
151 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
152 // with the Status error message
153 m_Imported = bool(status);
154 if (!m_Imported)
155 {
156 throw MemoryImportException(status.error_description());
157 }
158 return m_Imported;
159 }
160 }
161 }
162 return false;
163 }
164
telsoa014fcda012018-03-09 14:13:49 +0000165private:
David Beck09e2f272018-10-30 11:38:41 +0000166 // Only used for testing
167 void CopyOutTo(void* memory) const override
168 {
169 switch (this->GetDataType())
170 {
171 case arm_compute::DataType::F32:
172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173 static_cast<float*>(memory));
174 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000175 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000176 case arm_compute::DataType::QASYMM8:
177 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
178 static_cast<uint8_t*>(memory));
179 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100180 case arm_compute::DataType::BFLOAT16:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<armnn::BFloat16*>(memory));
183 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100184 case arm_compute::DataType::F16:
185 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
186 static_cast<armnn::Half*>(memory));
187 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100188 case arm_compute::DataType::S16:
189 case arm_compute::DataType::QSYMM16:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<int16_t*>(memory));
192 break;
James Conroyd47a0642019-09-17 14:22:06 +0100193 case arm_compute::DataType::S32:
194 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195 static_cast<int32_t*>(memory));
196 break;
David Beck09e2f272018-10-30 11:38:41 +0000197 default:
198 {
199 throw armnn::UnimplementedException();
200 }
201 }
202 }
203
204 // Only used for testing
205 void CopyInFrom(const void* memory) override
206 {
207 switch (this->GetDataType())
208 {
209 case arm_compute::DataType::F32:
210 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
211 this->GetTensor());
212 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000213 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000214 case arm_compute::DataType::QASYMM8:
215 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
216 this->GetTensor());
217 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100218 case arm_compute::DataType::BFLOAT16:
219 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
220 this->GetTensor());
221 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100222 case arm_compute::DataType::F16:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
224 this->GetTensor());
225 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100226 case arm_compute::DataType::S16:
227 case arm_compute::DataType::QSYMM16:
228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
229 this->GetTensor());
230 break;
James Conroyd47a0642019-09-17 14:22:06 +0100231 case arm_compute::DataType::S32:
232 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
233 this->GetTensor());
234 break;
David Beck09e2f272018-10-30 11:38:41 +0000235 default:
236 {
237 throw armnn::UnimplementedException();
238 }
239 }
240 }
241
telsoa014fcda012018-03-09 14:13:49 +0000242 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100243 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100244 MemorySourceFlags m_ImportFlags;
245 bool m_Imported;
246 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000247};
248
Derek Lambertic81855f2019-06-13 17:34:19 +0100249class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000250{
251public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100252 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100253 const arm_compute::TensorShape& shape,
254 const arm_compute::Coordinates& coords)
255 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000256 {
telsoa01c577f2c2018-08-31 09:22:23 +0100257 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259
260 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
261 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100262
263 virtual void Allocate() override {}
264 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000265
telsoa01c577f2c2018-08-31 09:22:23 +0100266 virtual ITensorHandle* GetParent() const override { return parentHandle; }
267
telsoa014fcda012018-03-09 14:13:49 +0000268 virtual arm_compute::DataType GetDataType() const override
269 {
270 return m_Tensor.info()->data_type();
271 }
272
telsoa01c577f2c2018-08-31 09:22:23 +0100273 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
274
275 virtual const void* Map(bool /* blocking = true */) const override
276 {
277 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
278 }
279 virtual void Unmap() const override {}
280
281 TensorShape GetStrides() const override
282 {
283 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
284 }
285
286 TensorShape GetShape() const override
287 {
288 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
289 }
David Beck09e2f272018-10-30 11:38:41 +0000290
telsoa014fcda012018-03-09 14:13:49 +0000291private:
David Beck09e2f272018-10-30 11:38:41 +0000292 // Only used for testing
293 void CopyOutTo(void* memory) const override
294 {
295 switch (this->GetDataType())
296 {
297 case arm_compute::DataType::F32:
298 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
299 static_cast<float*>(memory));
300 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000301 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000302 case arm_compute::DataType::QASYMM8:
303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304 static_cast<uint8_t*>(memory));
305 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100306 case arm_compute::DataType::S16:
307 case arm_compute::DataType::QSYMM16:
308 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309 static_cast<int16_t*>(memory));
310 break;
James Conroyd47a0642019-09-17 14:22:06 +0100311 case arm_compute::DataType::S32:
312 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
313 static_cast<int32_t*>(memory));
314 break;
David Beck09e2f272018-10-30 11:38:41 +0000315 default:
316 {
317 throw armnn::UnimplementedException();
318 }
319 }
320 }
321
322 // Only used for testing
323 void CopyInFrom(const void* memory) override
324 {
325 switch (this->GetDataType())
326 {
327 case arm_compute::DataType::F32:
328 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
329 this->GetTensor());
330 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000331 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000332 case arm_compute::DataType::QASYMM8:
333 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
334 this->GetTensor());
335 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100336 case arm_compute::DataType::S16:
337 case arm_compute::DataType::QSYMM16:
338 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
339 this->GetTensor());
340 break;
James Conroyd47a0642019-09-17 14:22:06 +0100341 case arm_compute::DataType::S32:
342 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
343 this->GetTensor());
344 break;
David Beck09e2f272018-10-30 11:38:41 +0000345 default:
346 {
347 throw armnn::UnimplementedException();
348 }
349 }
350 }
351
telsoa01c577f2c2018-08-31 09:22:23 +0100352 arm_compute::SubTensor m_Tensor;
353 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000354};
355
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100356} // namespace armnn