blob: c3662c1211bf0e3bd9b3baa1d074549e333039d6 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
telsoa01c577f2c2018-08-31 09:22:23 +010011#include <arm_compute/runtime/MemoryGroup.h>
12#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000013#include <arm_compute/runtime/Tensor.h>
14#include <arm_compute/runtime/SubTensor.h>
15#include <arm_compute/core/TensorShape.h>
16#include <arm_compute/core/Coordinates.h>
17
telsoa01c577f2c2018-08-31 09:22:23 +010018#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
20namespace armnn
21{
22
Derek Lambertic81855f2019-06-13 17:34:19 +010023class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000024{
25public:
26 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010027 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
28 m_Imported(false),
29 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000030 {
31 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
32 }
33
David Monahan3fb7e102019-08-20 11:25:29 +010034 NeonTensorHandle(const TensorInfo& tensorInfo,
35 DataLayout dataLayout,
36 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
37 : m_ImportFlags(importFlags),
38 m_Imported(false),
39 m_IsImportEnabled(false)
40
Francis Murtagh351d13d2018-09-24 15:01:18 +010041 {
42 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
43 }
44
telsoa014fcda012018-03-09 14:13:49 +000045 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
46 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010047
telsoa014fcda012018-03-09 14:13:49 +000048 virtual void Allocate() override
49 {
David Monahan3fb7e102019-08-20 11:25:29 +010050 // If we have enabled Importing, don't Allocate the tensor
51 if (!m_IsImportEnabled)
52 {
53 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
54 }
telsoa014fcda012018-03-09 14:13:49 +000055 };
56
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual void Manage() override
58 {
David Monahan3fb7e102019-08-20 11:25:29 +010059 // If we have enabled Importing, don't manage the tensor
60 if (!m_IsImportEnabled)
61 {
62 BOOST_ASSERT(m_MemoryGroup != nullptr);
63 m_MemoryGroup->manage(&m_Tensor);
64 }
telsoa01c577f2c2018-08-31 09:22:23 +010065 }
66
telsoa01c577f2c2018-08-31 09:22:23 +010067 virtual ITensorHandle* GetParent() const override { return nullptr; }
68
telsoa014fcda012018-03-09 14:13:49 +000069 virtual arm_compute::DataType GetDataType() const override
70 {
71 return m_Tensor.info()->data_type();
72 }
73
telsoa01c577f2c2018-08-31 09:22:23 +010074 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 {
76 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
77 }
78
79 virtual const void* Map(bool /* blocking = true */) const override
80 {
81 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
82 }
telsoa01c577f2c2018-08-31 09:22:23 +010083
David Monahan3fb7e102019-08-20 11:25:29 +010084 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010085
86 TensorShape GetStrides() const override
87 {
88 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
89 }
90
91 TensorShape GetShape() const override
92 {
93 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
94 }
95
David Monahan3fb7e102019-08-20 11:25:29 +010096 void SetImportFlags(MemorySourceFlags importFlags)
97 {
98 m_ImportFlags = importFlags;
99 }
100
101 MemorySourceFlags GetImportFlags() const override
102 {
103 return m_ImportFlags;
104 }
105
106 void SetImportEnabledFlag(bool importEnabledFlag)
107 {
108 m_IsImportEnabled = importEnabledFlag;
109 }
110
111 virtual bool Import(void* memory, MemorySource source) override
112 {
113 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
114 {
115 if (source == MemorySource::Malloc && m_IsImportEnabled)
116 {
117 // Checks the 16 byte memory alignment
118 constexpr uintptr_t alignment = sizeof(size_t);
119 if (reinterpret_cast<uintptr_t>(memory) % alignment)
120 {
121 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
122 }
123
124 // m_Tensor not yet Allocated
125 if (!m_Imported && !m_Tensor.buffer())
126 {
127 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
128 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
129 // with the Status error message
130 m_Imported = bool(status);
131 if (!m_Imported)
132 {
133 throw MemoryImportException(status.error_description());
134 }
135 return m_Imported;
136 }
137
138 // m_Tensor.buffer() initially allocated with Allocate().
139 if (!m_Imported && m_Tensor.buffer())
140 {
141 throw MemoryImportException(
142 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
143 }
144
145 // m_Tensor.buffer() previously imported.
146 if (m_Imported)
147 {
148 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
149 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
150 // with the Status error message
151 m_Imported = bool(status);
152 if (!m_Imported)
153 {
154 throw MemoryImportException(status.error_description());
155 }
156 return m_Imported;
157 }
158 }
159 }
160 return false;
161 }
162
telsoa014fcda012018-03-09 14:13:49 +0000163private:
David Beck09e2f272018-10-30 11:38:41 +0000164 // Only used for testing
165 void CopyOutTo(void* memory) const override
166 {
167 switch (this->GetDataType())
168 {
169 case arm_compute::DataType::F32:
170 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
171 static_cast<float*>(memory));
172 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000173 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000174 case arm_compute::DataType::QASYMM8:
175 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
176 static_cast<uint8_t*>(memory));
177 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100178 case arm_compute::DataType::S16:
179 case arm_compute::DataType::QSYMM16:
180 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
181 static_cast<int16_t*>(memory));
182 break;
David Beck09e2f272018-10-30 11:38:41 +0000183 default:
184 {
185 throw armnn::UnimplementedException();
186 }
187 }
188 }
189
190 // Only used for testing
191 void CopyInFrom(const void* memory) override
192 {
193 switch (this->GetDataType())
194 {
195 case arm_compute::DataType::F32:
196 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
197 this->GetTensor());
198 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000199 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000200 case arm_compute::DataType::QASYMM8:
201 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
202 this->GetTensor());
203 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100204 case arm_compute::DataType::S16:
205 case arm_compute::DataType::QSYMM16:
206 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
207 this->GetTensor());
208 break;
David Beck09e2f272018-10-30 11:38:41 +0000209 default:
210 {
211 throw armnn::UnimplementedException();
212 }
213 }
214 }
215
telsoa014fcda012018-03-09 14:13:49 +0000216 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100217 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100218 MemorySourceFlags m_ImportFlags;
219 bool m_Imported;
220 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000221};
222
Derek Lambertic81855f2019-06-13 17:34:19 +0100223class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000224{
225public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100226 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100227 const arm_compute::TensorShape& shape,
228 const arm_compute::Coordinates& coords)
229 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000230 {
telsoa01c577f2c2018-08-31 09:22:23 +0100231 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000232 }
233
234 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
235 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100236
237 virtual void Allocate() override {}
238 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000239
telsoa01c577f2c2018-08-31 09:22:23 +0100240 virtual ITensorHandle* GetParent() const override { return parentHandle; }
241
telsoa014fcda012018-03-09 14:13:49 +0000242 virtual arm_compute::DataType GetDataType() const override
243 {
244 return m_Tensor.info()->data_type();
245 }
246
telsoa01c577f2c2018-08-31 09:22:23 +0100247 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
248
249 virtual const void* Map(bool /* blocking = true */) const override
250 {
251 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
252 }
253 virtual void Unmap() const override {}
254
255 TensorShape GetStrides() const override
256 {
257 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
258 }
259
260 TensorShape GetShape() const override
261 {
262 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
263 }
David Beck09e2f272018-10-30 11:38:41 +0000264
telsoa014fcda012018-03-09 14:13:49 +0000265private:
David Beck09e2f272018-10-30 11:38:41 +0000266 // Only used for testing
267 void CopyOutTo(void* memory) const override
268 {
269 switch (this->GetDataType())
270 {
271 case arm_compute::DataType::F32:
272 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
273 static_cast<float*>(memory));
274 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000275 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000276 case arm_compute::DataType::QASYMM8:
277 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
278 static_cast<uint8_t*>(memory));
279 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100280 case arm_compute::DataType::S16:
281 case arm_compute::DataType::QSYMM16:
282 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
283 static_cast<int16_t*>(memory));
284 break;
David Beck09e2f272018-10-30 11:38:41 +0000285 default:
286 {
287 throw armnn::UnimplementedException();
288 }
289 }
290 }
291
292 // Only used for testing
293 void CopyInFrom(const void* memory) override
294 {
295 switch (this->GetDataType())
296 {
297 case arm_compute::DataType::F32:
298 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
299 this->GetTensor());
300 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000301 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000302 case arm_compute::DataType::QASYMM8:
303 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
304 this->GetTensor());
305 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100306 case arm_compute::DataType::S16:
307 case arm_compute::DataType::QSYMM16:
308 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
309 this->GetTensor());
310 break;
David Beck09e2f272018-10-30 11:38:41 +0000311 default:
312 {
313 throw armnn::UnimplementedException();
314 }
315 }
316 }
317
telsoa01c577f2c2018-08-31 09:22:23 +0100318 arm_compute::SubTensor m_Tensor;
319 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000320};
321
David Beck09e2f272018-10-30 11:38:41 +0000322} // namespace armnn