blob: 2e9be11be165faf076eb396216cbd7eca716cb82 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include <Half.hpp>

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <boost/polymorphic_pointer_cast.hpp>

#include <cstdint>
#include <memory>
telsoa014fcda012018-03-09 14:13:49 +000020
21namespace armnn
22{
23
Derek Lambertic81855f2019-06-13 17:34:19 +010024class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000025{
26public:
27 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010028 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
29 m_Imported(false),
30 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000031 {
32 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
33 }
34
David Monahan3fb7e102019-08-20 11:25:29 +010035 NeonTensorHandle(const TensorInfo& tensorInfo,
36 DataLayout dataLayout,
37 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
38 : m_ImportFlags(importFlags),
39 m_Imported(false),
40 m_IsImportEnabled(false)
41
Francis Murtagh351d13d2018-09-24 15:01:18 +010042 {
43 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
44 }
45
telsoa014fcda012018-03-09 14:13:49 +000046 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
47 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010048
telsoa014fcda012018-03-09 14:13:49 +000049 virtual void Allocate() override
50 {
David Monahan3fb7e102019-08-20 11:25:29 +010051 // If we have enabled Importing, don't Allocate the tensor
52 if (!m_IsImportEnabled)
53 {
54 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
55 }
telsoa014fcda012018-03-09 14:13:49 +000056 };
57
telsoa01c577f2c2018-08-31 09:22:23 +010058 virtual void Manage() override
59 {
David Monahan3fb7e102019-08-20 11:25:29 +010060 // If we have enabled Importing, don't manage the tensor
61 if (!m_IsImportEnabled)
62 {
63 BOOST_ASSERT(m_MemoryGroup != nullptr);
64 m_MemoryGroup->manage(&m_Tensor);
65 }
telsoa01c577f2c2018-08-31 09:22:23 +010066 }
67
telsoa01c577f2c2018-08-31 09:22:23 +010068 virtual ITensorHandle* GetParent() const override { return nullptr; }
69
telsoa014fcda012018-03-09 14:13:49 +000070 virtual arm_compute::DataType GetDataType() const override
71 {
72 return m_Tensor.info()->data_type();
73 }
74
telsoa01c577f2c2018-08-31 09:22:23 +010075 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
76 {
77 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
78 }
79
80 virtual const void* Map(bool /* blocking = true */) const override
81 {
82 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
83 }
telsoa01c577f2c2018-08-31 09:22:23 +010084
David Monahan3fb7e102019-08-20 11:25:29 +010085 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010086
87 TensorShape GetStrides() const override
88 {
89 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
90 }
91
92 TensorShape GetShape() const override
93 {
94 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
95 }
96
David Monahan3fb7e102019-08-20 11:25:29 +010097 void SetImportFlags(MemorySourceFlags importFlags)
98 {
99 m_ImportFlags = importFlags;
100 }
101
102 MemorySourceFlags GetImportFlags() const override
103 {
104 return m_ImportFlags;
105 }
106
107 void SetImportEnabledFlag(bool importEnabledFlag)
108 {
109 m_IsImportEnabled = importEnabledFlag;
110 }
111
112 virtual bool Import(void* memory, MemorySource source) override
113 {
114 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
115 {
116 if (source == MemorySource::Malloc && m_IsImportEnabled)
117 {
118 // Checks the 16 byte memory alignment
119 constexpr uintptr_t alignment = sizeof(size_t);
120 if (reinterpret_cast<uintptr_t>(memory) % alignment)
121 {
122 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
123 }
124
125 // m_Tensor not yet Allocated
126 if (!m_Imported && !m_Tensor.buffer())
127 {
128 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
129 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
130 // with the Status error message
131 m_Imported = bool(status);
132 if (!m_Imported)
133 {
134 throw MemoryImportException(status.error_description());
135 }
136 return m_Imported;
137 }
138
139 // m_Tensor.buffer() initially allocated with Allocate().
140 if (!m_Imported && m_Tensor.buffer())
141 {
142 throw MemoryImportException(
143 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
144 }
145
146 // m_Tensor.buffer() previously imported.
147 if (m_Imported)
148 {
149 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
150 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
151 // with the Status error message
152 m_Imported = bool(status);
153 if (!m_Imported)
154 {
155 throw MemoryImportException(status.error_description());
156 }
157 return m_Imported;
158 }
159 }
160 }
161 return false;
162 }
163
telsoa014fcda012018-03-09 14:13:49 +0000164private:
David Beck09e2f272018-10-30 11:38:41 +0000165 // Only used for testing
166 void CopyOutTo(void* memory) const override
167 {
168 switch (this->GetDataType())
169 {
170 case arm_compute::DataType::F32:
171 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
172 static_cast<float*>(memory));
173 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000174 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000175 case arm_compute::DataType::QASYMM8:
176 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
177 static_cast<uint8_t*>(memory));
178 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100179 case arm_compute::DataType::F16:
180 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
181 static_cast<armnn::Half*>(memory));
182 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100183 case arm_compute::DataType::S16:
184 case arm_compute::DataType::QSYMM16:
185 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
186 static_cast<int16_t*>(memory));
187 break;
James Conroyd47a0642019-09-17 14:22:06 +0100188 case arm_compute::DataType::S32:
189 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
190 static_cast<int32_t*>(memory));
191 break;
David Beck09e2f272018-10-30 11:38:41 +0000192 default:
193 {
194 throw armnn::UnimplementedException();
195 }
196 }
197 }
198
199 // Only used for testing
200 void CopyInFrom(const void* memory) override
201 {
202 switch (this->GetDataType())
203 {
204 case arm_compute::DataType::F32:
205 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
206 this->GetTensor());
207 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000208 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000209 case arm_compute::DataType::QASYMM8:
210 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
211 this->GetTensor());
212 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100213 case arm_compute::DataType::F16:
214 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
215 this->GetTensor());
216 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100217 case arm_compute::DataType::S16:
218 case arm_compute::DataType::QSYMM16:
219 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
220 this->GetTensor());
221 break;
James Conroyd47a0642019-09-17 14:22:06 +0100222 case arm_compute::DataType::S32:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
224 this->GetTensor());
225 break;
David Beck09e2f272018-10-30 11:38:41 +0000226 default:
227 {
228 throw armnn::UnimplementedException();
229 }
230 }
231 }
232
telsoa014fcda012018-03-09 14:13:49 +0000233 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100234 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100235 MemorySourceFlags m_ImportFlags;
236 bool m_Imported;
237 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000238};
239
Derek Lambertic81855f2019-06-13 17:34:19 +0100240class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000241{
242public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100243 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100244 const arm_compute::TensorShape& shape,
245 const arm_compute::Coordinates& coords)
246 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000247 {
telsoa01c577f2c2018-08-31 09:22:23 +0100248 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000249 }
250
251 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
252 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100253
254 virtual void Allocate() override {}
255 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000256
telsoa01c577f2c2018-08-31 09:22:23 +0100257 virtual ITensorHandle* GetParent() const override { return parentHandle; }
258
telsoa014fcda012018-03-09 14:13:49 +0000259 virtual arm_compute::DataType GetDataType() const override
260 {
261 return m_Tensor.info()->data_type();
262 }
263
telsoa01c577f2c2018-08-31 09:22:23 +0100264 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
265
266 virtual const void* Map(bool /* blocking = true */) const override
267 {
268 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
269 }
270 virtual void Unmap() const override {}
271
272 TensorShape GetStrides() const override
273 {
274 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
275 }
276
277 TensorShape GetShape() const override
278 {
279 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
280 }
David Beck09e2f272018-10-30 11:38:41 +0000281
telsoa014fcda012018-03-09 14:13:49 +0000282private:
David Beck09e2f272018-10-30 11:38:41 +0000283 // Only used for testing
284 void CopyOutTo(void* memory) const override
285 {
286 switch (this->GetDataType())
287 {
288 case arm_compute::DataType::F32:
289 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
290 static_cast<float*>(memory));
291 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000292 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000293 case arm_compute::DataType::QASYMM8:
294 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
295 static_cast<uint8_t*>(memory));
296 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100297 case arm_compute::DataType::S16:
298 case arm_compute::DataType::QSYMM16:
299 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
300 static_cast<int16_t*>(memory));
301 break;
James Conroyd47a0642019-09-17 14:22:06 +0100302 case arm_compute::DataType::S32:
303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304 static_cast<int32_t*>(memory));
305 break;
David Beck09e2f272018-10-30 11:38:41 +0000306 default:
307 {
308 throw armnn::UnimplementedException();
309 }
310 }
311 }
312
313 // Only used for testing
314 void CopyInFrom(const void* memory) override
315 {
316 switch (this->GetDataType())
317 {
318 case arm_compute::DataType::F32:
319 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
320 this->GetTensor());
321 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000322 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000323 case arm_compute::DataType::QASYMM8:
324 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
325 this->GetTensor());
326 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100327 case arm_compute::DataType::S16:
328 case arm_compute::DataType::QSYMM16:
329 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
330 this->GetTensor());
331 break;
James Conroyd47a0642019-09-17 14:22:06 +0100332 case arm_compute::DataType::S32:
333 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
334 this->GetTensor());
335 break;
David Beck09e2f272018-10-30 11:38:41 +0000336 default:
337 {
338 throw armnn::UnimplementedException();
339 }
340 }
341 }
342
telsoa01c577f2c2018-08-31 09:22:23 +0100343 arm_compute::SubTensor m_Tensor;
344 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000345};
346
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100347} // namespace armnn