blob: fb2c2b51284413e8ebbe7f3febc2af9086b646d3 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Narumol Prangnawarat250d3922020-03-30 16:11:04 +01007#include <BFloat16.hpp>
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01008#include <Half.hpp>
9
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
11
Derek Lambertic81855f2019-06-13 17:34:19 +010012#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
telsoa01c577f2c2018-08-31 09:22:23 +010015#include <arm_compute/runtime/MemoryGroup.h>
16#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000017#include <arm_compute/runtime/Tensor.h>
18#include <arm_compute/runtime/SubTensor.h>
19#include <arm_compute/core/TensorShape.h>
20#include <arm_compute/core/Coordinates.h>
21
telsoa01c577f2c2018-08-31 09:22:23 +010022#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000023
24namespace armnn
25{
26
Derek Lambertic81855f2019-06-13 17:34:19 +010027class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000028{
29public:
30 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010031 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
32 m_Imported(false),
33 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000034 {
35 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36 }
37
David Monahan3fb7e102019-08-20 11:25:29 +010038 NeonTensorHandle(const TensorInfo& tensorInfo,
39 DataLayout dataLayout,
40 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 : m_ImportFlags(importFlags),
42 m_Imported(false),
43 m_IsImportEnabled(false)
44
Francis Murtagh351d13d2018-09-24 15:01:18 +010045 {
46 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
47 }
48
telsoa014fcda012018-03-09 14:13:49 +000049 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
50 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010051
telsoa014fcda012018-03-09 14:13:49 +000052 virtual void Allocate() override
53 {
David Monahan3fb7e102019-08-20 11:25:29 +010054 // If we have enabled Importing, don't Allocate the tensor
55 if (!m_IsImportEnabled)
56 {
57 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
58 }
telsoa014fcda012018-03-09 14:13:49 +000059 };
60
telsoa01c577f2c2018-08-31 09:22:23 +010061 virtual void Manage() override
62 {
David Monahan3fb7e102019-08-20 11:25:29 +010063 // If we have enabled Importing, don't manage the tensor
64 if (!m_IsImportEnabled)
65 {
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010066 ARMNN_ASSERT(m_MemoryGroup != nullptr);
David Monahan3fb7e102019-08-20 11:25:29 +010067 m_MemoryGroup->manage(&m_Tensor);
68 }
telsoa01c577f2c2018-08-31 09:22:23 +010069 }
70
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual ITensorHandle* GetParent() const override { return nullptr; }
72
telsoa014fcda012018-03-09 14:13:49 +000073 virtual arm_compute::DataType GetDataType() const override
74 {
75 return m_Tensor.info()->data_type();
76 }
77
telsoa01c577f2c2018-08-31 09:22:23 +010078 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
79 {
80 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
81 }
82
83 virtual const void* Map(bool /* blocking = true */) const override
84 {
85 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
86 }
telsoa01c577f2c2018-08-31 09:22:23 +010087
David Monahan3fb7e102019-08-20 11:25:29 +010088 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010089
90 TensorShape GetStrides() const override
91 {
92 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
93 }
94
95 TensorShape GetShape() const override
96 {
97 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
98 }
99
David Monahan3fb7e102019-08-20 11:25:29 +0100100 void SetImportFlags(MemorySourceFlags importFlags)
101 {
102 m_ImportFlags = importFlags;
103 }
104
105 MemorySourceFlags GetImportFlags() const override
106 {
107 return m_ImportFlags;
108 }
109
110 void SetImportEnabledFlag(bool importEnabledFlag)
111 {
112 m_IsImportEnabled = importEnabledFlag;
113 }
114
115 virtual bool Import(void* memory, MemorySource source) override
116 {
117 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
118 {
119 if (source == MemorySource::Malloc && m_IsImportEnabled)
120 {
121 // Checks the 16 byte memory alignment
122 constexpr uintptr_t alignment = sizeof(size_t);
123 if (reinterpret_cast<uintptr_t>(memory) % alignment)
124 {
125 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
126 }
127
128 // m_Tensor not yet Allocated
129 if (!m_Imported && !m_Tensor.buffer())
130 {
131 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
132 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
133 // with the Status error message
134 m_Imported = bool(status);
135 if (!m_Imported)
136 {
137 throw MemoryImportException(status.error_description());
138 }
139 return m_Imported;
140 }
141
142 // m_Tensor.buffer() initially allocated with Allocate().
143 if (!m_Imported && m_Tensor.buffer())
144 {
145 throw MemoryImportException(
146 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
147 }
148
149 // m_Tensor.buffer() previously imported.
150 if (m_Imported)
151 {
152 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
153 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
154 // with the Status error message
155 m_Imported = bool(status);
156 if (!m_Imported)
157 {
158 throw MemoryImportException(status.error_description());
159 }
160 return m_Imported;
161 }
162 }
163 }
164 return false;
165 }
166
telsoa014fcda012018-03-09 14:13:49 +0000167private:
David Beck09e2f272018-10-30 11:38:41 +0000168 // Only used for testing
169 void CopyOutTo(void* memory) const override
170 {
171 switch (this->GetDataType())
172 {
173 case arm_compute::DataType::F32:
174 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
175 static_cast<float*>(memory));
176 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000177 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000178 case arm_compute::DataType::QASYMM8:
179 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
180 static_cast<uint8_t*>(memory));
181 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100182 case arm_compute::DataType::BFLOAT16:
183 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
184 static_cast<armnn::BFloat16*>(memory));
185 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100186 case arm_compute::DataType::F16:
187 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
188 static_cast<armnn::Half*>(memory));
189 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100190 case arm_compute::DataType::S16:
191 case arm_compute::DataType::QSYMM16:
192 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
193 static_cast<int16_t*>(memory));
194 break;
James Conroyd47a0642019-09-17 14:22:06 +0100195 case arm_compute::DataType::S32:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<int32_t*>(memory));
198 break;
David Beck09e2f272018-10-30 11:38:41 +0000199 default:
200 {
201 throw armnn::UnimplementedException();
202 }
203 }
204 }
205
206 // Only used for testing
207 void CopyInFrom(const void* memory) override
208 {
209 switch (this->GetDataType())
210 {
211 case arm_compute::DataType::F32:
212 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
213 this->GetTensor());
214 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000215 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000216 case arm_compute::DataType::QASYMM8:
217 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
218 this->GetTensor());
219 break;
Narumol Prangnawarat250d3922020-03-30 16:11:04 +0100220 case arm_compute::DataType::BFLOAT16:
221 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
222 this->GetTensor());
223 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100224 case arm_compute::DataType::F16:
225 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
226 this->GetTensor());
227 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100228 case arm_compute::DataType::S16:
229 case arm_compute::DataType::QSYMM16:
230 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
231 this->GetTensor());
232 break;
James Conroyd47a0642019-09-17 14:22:06 +0100233 case arm_compute::DataType::S32:
234 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
235 this->GetTensor());
236 break;
David Beck09e2f272018-10-30 11:38:41 +0000237 default:
238 {
239 throw armnn::UnimplementedException();
240 }
241 }
242 }
243
telsoa014fcda012018-03-09 14:13:49 +0000244 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100245 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100246 MemorySourceFlags m_ImportFlags;
247 bool m_Imported;
248 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000249};
250
Derek Lambertic81855f2019-06-13 17:34:19 +0100251class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000252{
253public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100254 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100255 const arm_compute::TensorShape& shape,
256 const arm_compute::Coordinates& coords)
257 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000258 {
telsoa01c577f2c2018-08-31 09:22:23 +0100259 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000260 }
261
262 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
263 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100264
265 virtual void Allocate() override {}
266 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000267
telsoa01c577f2c2018-08-31 09:22:23 +0100268 virtual ITensorHandle* GetParent() const override { return parentHandle; }
269
telsoa014fcda012018-03-09 14:13:49 +0000270 virtual arm_compute::DataType GetDataType() const override
271 {
272 return m_Tensor.info()->data_type();
273 }
274
telsoa01c577f2c2018-08-31 09:22:23 +0100275 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
276
277 virtual const void* Map(bool /* blocking = true */) const override
278 {
279 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
280 }
281 virtual void Unmap() const override {}
282
283 TensorShape GetStrides() const override
284 {
285 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
286 }
287
288 TensorShape GetShape() const override
289 {
290 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
291 }
David Beck09e2f272018-10-30 11:38:41 +0000292
telsoa014fcda012018-03-09 14:13:49 +0000293private:
David Beck09e2f272018-10-30 11:38:41 +0000294 // Only used for testing
295 void CopyOutTo(void* memory) const override
296 {
297 switch (this->GetDataType())
298 {
299 case arm_compute::DataType::F32:
300 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
301 static_cast<float*>(memory));
302 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000303 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000304 case arm_compute::DataType::QASYMM8:
305 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
306 static_cast<uint8_t*>(memory));
307 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100308 case arm_compute::DataType::S16:
309 case arm_compute::DataType::QSYMM16:
310 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
311 static_cast<int16_t*>(memory));
312 break;
James Conroyd47a0642019-09-17 14:22:06 +0100313 case arm_compute::DataType::S32:
314 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
315 static_cast<int32_t*>(memory));
316 break;
David Beck09e2f272018-10-30 11:38:41 +0000317 default:
318 {
319 throw armnn::UnimplementedException();
320 }
321 }
322 }
323
324 // Only used for testing
325 void CopyInFrom(const void* memory) override
326 {
327 switch (this->GetDataType())
328 {
329 case arm_compute::DataType::F32:
330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
331 this->GetTensor());
332 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000333 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000334 case arm_compute::DataType::QASYMM8:
335 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
336 this->GetTensor());
337 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100338 case arm_compute::DataType::S16:
339 case arm_compute::DataType::QSYMM16:
340 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
341 this->GetTensor());
342 break;
James Conroyd47a0642019-09-17 14:22:06 +0100343 case arm_compute::DataType::S32:
344 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
345 this->GetTensor());
346 break;
David Beck09e2f272018-10-30 11:38:41 +0000347 default:
348 {
349 throw armnn::UnimplementedException();
350 }
351 }
352 }
353
telsoa01c577f2c2018-08-31 09:22:23 +0100354 arm_compute::SubTensor m_Tensor;
355 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000356};
357
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100358} // namespace armnn