blob: ca5bfb04b1e44199e76fc0642be03558ff98e5dd [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tar99836d32019-09-30 16:34:31 +01007#include <Half.hpp>
8
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +010010#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
telsoa01c577f2c2018-08-31 09:22:23 +010013#include <arm_compute/runtime/MemoryGroup.h>
14#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000015#include <arm_compute/runtime/Tensor.h>
16#include <arm_compute/runtime/SubTensor.h>
17#include <arm_compute/core/TensorShape.h>
18#include <arm_compute/core/Coordinates.h>
19
telsoa01c577f2c2018-08-31 09:22:23 +010020#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22namespace armnn
23{
24
Derek Lambertic81855f2019-06-13 17:34:19 +010025class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000026{
27public:
28 NeonTensorHandle(const TensorInfo& tensorInfo)
David Monahan3fb7e102019-08-20 11:25:29 +010029 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
30 m_Imported(false),
31 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000032 {
33 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
34 }
35
David Monahan3fb7e102019-08-20 11:25:29 +010036 NeonTensorHandle(const TensorInfo& tensorInfo,
37 DataLayout dataLayout,
38 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
39 : m_ImportFlags(importFlags),
40 m_Imported(false),
41 m_IsImportEnabled(false)
42
Francis Murtagh351d13d2018-09-24 15:01:18 +010043 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
48 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010049
telsoa014fcda012018-03-09 14:13:49 +000050 virtual void Allocate() override
51 {
David Monahan3fb7e102019-08-20 11:25:29 +010052 // If we have enabled Importing, don't Allocate the tensor
53 if (!m_IsImportEnabled)
54 {
55 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
56 }
telsoa014fcda012018-03-09 14:13:49 +000057 };
58
telsoa01c577f2c2018-08-31 09:22:23 +010059 virtual void Manage() override
60 {
David Monahan3fb7e102019-08-20 11:25:29 +010061 // If we have enabled Importing, don't manage the tensor
62 if (!m_IsImportEnabled)
63 {
64 BOOST_ASSERT(m_MemoryGroup != nullptr);
65 m_MemoryGroup->manage(&m_Tensor);
66 }
telsoa01c577f2c2018-08-31 09:22:23 +010067 }
68
telsoa01c577f2c2018-08-31 09:22:23 +010069 virtual ITensorHandle* GetParent() const override { return nullptr; }
70
telsoa014fcda012018-03-09 14:13:49 +000071 virtual arm_compute::DataType GetDataType() const override
72 {
73 return m_Tensor.info()->data_type();
74 }
75
telsoa01c577f2c2018-08-31 09:22:23 +010076 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
77 {
78 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
79 }
80
81 virtual const void* Map(bool /* blocking = true */) const override
82 {
83 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
84 }
telsoa01c577f2c2018-08-31 09:22:23 +010085
David Monahan3fb7e102019-08-20 11:25:29 +010086 virtual void Unmap() const override {}
telsoa01c577f2c2018-08-31 09:22:23 +010087
88 TensorShape GetStrides() const override
89 {
90 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
91 }
92
93 TensorShape GetShape() const override
94 {
95 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
96 }
97
David Monahan3fb7e102019-08-20 11:25:29 +010098 void SetImportFlags(MemorySourceFlags importFlags)
99 {
100 m_ImportFlags = importFlags;
101 }
102
103 MemorySourceFlags GetImportFlags() const override
104 {
105 return m_ImportFlags;
106 }
107
108 void SetImportEnabledFlag(bool importEnabledFlag)
109 {
110 m_IsImportEnabled = importEnabledFlag;
111 }
112
113 virtual bool Import(void* memory, MemorySource source) override
114 {
115 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
116 {
117 if (source == MemorySource::Malloc && m_IsImportEnabled)
118 {
119 // Checks the 16 byte memory alignment
120 constexpr uintptr_t alignment = sizeof(size_t);
121 if (reinterpret_cast<uintptr_t>(memory) % alignment)
122 {
123 throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
124 }
125
126 // m_Tensor not yet Allocated
127 if (!m_Imported && !m_Tensor.buffer())
128 {
129 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
130 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
131 // with the Status error message
132 m_Imported = bool(status);
133 if (!m_Imported)
134 {
135 throw MemoryImportException(status.error_description());
136 }
137 return m_Imported;
138 }
139
140 // m_Tensor.buffer() initially allocated with Allocate().
141 if (!m_Imported && m_Tensor.buffer())
142 {
143 throw MemoryImportException(
144 "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
145 }
146
147 // m_Tensor.buffer() previously imported.
148 if (m_Imported)
149 {
150 arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
151 // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
152 // with the Status error message
153 m_Imported = bool(status);
154 if (!m_Imported)
155 {
156 throw MemoryImportException(status.error_description());
157 }
158 return m_Imported;
159 }
160 }
161 }
162 return false;
163 }
164
telsoa014fcda012018-03-09 14:13:49 +0000165private:
David Beck09e2f272018-10-30 11:38:41 +0000166 // Only used for testing
167 void CopyOutTo(void* memory) const override
168 {
169 switch (this->GetDataType())
170 {
171 case arm_compute::DataType::F32:
172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173 static_cast<float*>(memory));
174 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000175 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000176 case arm_compute::DataType::QASYMM8:
177 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
178 static_cast<uint8_t*>(memory));
179 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100180 case arm_compute::DataType::F16:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<armnn::Half*>(memory));
183 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100184 case arm_compute::DataType::S16:
185 case arm_compute::DataType::QSYMM16:
186 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
187 static_cast<int16_t*>(memory));
188 break;
James Conroyd47a0642019-09-17 14:22:06 +0100189 case arm_compute::DataType::S32:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<int32_t*>(memory));
192 break;
David Beck09e2f272018-10-30 11:38:41 +0000193 default:
194 {
195 throw armnn::UnimplementedException();
196 }
197 }
198 }
199
200 // Only used for testing
201 void CopyInFrom(const void* memory) override
202 {
203 switch (this->GetDataType())
204 {
205 case arm_compute::DataType::F32:
206 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
207 this->GetTensor());
208 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000209 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000210 case arm_compute::DataType::QASYMM8:
211 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
212 this->GetTensor());
213 break;
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100214 case arm_compute::DataType::F16:
215 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
216 this->GetTensor());
217 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100218 case arm_compute::DataType::S16:
219 case arm_compute::DataType::QSYMM16:
220 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
221 this->GetTensor());
222 break;
James Conroyd47a0642019-09-17 14:22:06 +0100223 case arm_compute::DataType::S32:
224 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
225 this->GetTensor());
226 break;
David Beck09e2f272018-10-30 11:38:41 +0000227 default:
228 {
229 throw armnn::UnimplementedException();
230 }
231 }
232 }
233
telsoa014fcda012018-03-09 14:13:49 +0000234 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100235 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan3fb7e102019-08-20 11:25:29 +0100236 MemorySourceFlags m_ImportFlags;
237 bool m_Imported;
238 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000239};
240
Derek Lambertic81855f2019-06-13 17:34:19 +0100241class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000242{
243public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100244 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100245 const arm_compute::TensorShape& shape,
246 const arm_compute::Coordinates& coords)
247 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000248 {
telsoa01c577f2c2018-08-31 09:22:23 +0100249 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000250 }
251
252 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
253 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100254
255 virtual void Allocate() override {}
256 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000257
telsoa01c577f2c2018-08-31 09:22:23 +0100258 virtual ITensorHandle* GetParent() const override { return parentHandle; }
259
telsoa014fcda012018-03-09 14:13:49 +0000260 virtual arm_compute::DataType GetDataType() const override
261 {
262 return m_Tensor.info()->data_type();
263 }
264
telsoa01c577f2c2018-08-31 09:22:23 +0100265 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
266
267 virtual const void* Map(bool /* blocking = true */) const override
268 {
269 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
270 }
271 virtual void Unmap() const override {}
272
273 TensorShape GetStrides() const override
274 {
275 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
276 }
277
278 TensorShape GetShape() const override
279 {
280 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
281 }
David Beck09e2f272018-10-30 11:38:41 +0000282
telsoa014fcda012018-03-09 14:13:49 +0000283private:
David Beck09e2f272018-10-30 11:38:41 +0000284 // Only used for testing
285 void CopyOutTo(void* memory) const override
286 {
287 switch (this->GetDataType())
288 {
289 case arm_compute::DataType::F32:
290 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
291 static_cast<float*>(memory));
292 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000293 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000294 case arm_compute::DataType::QASYMM8:
295 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
296 static_cast<uint8_t*>(memory));
297 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100298 case arm_compute::DataType::S16:
299 case arm_compute::DataType::QSYMM16:
300 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
301 static_cast<int16_t*>(memory));
302 break;
James Conroyd47a0642019-09-17 14:22:06 +0100303 case arm_compute::DataType::S32:
304 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
305 static_cast<int32_t*>(memory));
306 break;
David Beck09e2f272018-10-30 11:38:41 +0000307 default:
308 {
309 throw armnn::UnimplementedException();
310 }
311 }
312 }
313
314 // Only used for testing
315 void CopyInFrom(const void* memory) override
316 {
317 switch (this->GetDataType())
318 {
319 case arm_compute::DataType::F32:
320 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
321 this->GetTensor());
322 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000323 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000324 case arm_compute::DataType::QASYMM8:
325 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
326 this->GetTensor());
327 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100328 case arm_compute::DataType::S16:
329 case arm_compute::DataType::QSYMM16:
330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
331 this->GetTensor());
332 break;
James Conroyd47a0642019-09-17 14:22:06 +0100333 case arm_compute::DataType::S32:
334 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
335 this->GetTensor());
336 break;
David Beck09e2f272018-10-30 11:38:41 +0000337 default:
338 {
339 throw armnn::UnimplementedException();
340 }
341 }
342 }
343
telsoa01c577f2c2018-08-31 09:22:23 +0100344 arm_compute::SubTensor m_Tensor;
345 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000346};
347
Aron Virginas-Tar99836d32019-09-30 16:34:31 +0100348} // namespace armnn