blob: f63f1faa07d7f493c6b46f43e9e05374ef5d9c0f [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Derek Lambertic81855f2019-06-13 17:34:19 +01007#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
Jan Eilers3c9e0452020-04-10 13:00:44 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
13
telsoa014fcda012018-03-09 14:13:49 +000014#include <arm_compute/runtime/CL/CLTensor.h>
15#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010017#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/core/TensorShape.h>
19#include <arm_compute/core/Coordinates.h>
20
Narumol Prangnawarat9ef36142022-01-25 15:15:34 +000021#include <cl/IClTensorHandle.hpp>
22
telsoa014fcda012018-03-09 14:13:49 +000023namespace armnn
24{
25
telsoa014fcda012018-03-09 14:13:49 +000026class ClTensorHandle : public IClTensorHandle
27{
28public:
29 ClTensorHandle(const TensorInfo& tensorInfo)
David Monahan66dbf5b2021-03-11 11:34:54 +000030 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
31 m_Imported(false),
32 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000033 {
34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35 }
36
David Monahan66dbf5b2021-03-11 11:34:54 +000037 ClTensorHandle(const TensorInfo& tensorInfo,
38 DataLayout dataLayout,
39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
40 : m_ImportFlags(importFlags),
41 m_Imported(false),
42 m_IsImportEnabled(false)
Francis Murtagh351d13d2018-09-24 15:01:18 +010043 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
David Monahan66dbf5b2021-03-11 11:34:54 +000049 virtual void Allocate() override
50 {
51 // If we have enabled Importing, don't allocate the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010052 if (m_IsImportEnabled)
53 {
54 throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
55 }
56 else
David Monahan66dbf5b2021-03-11 11:34:54 +000057 {
58 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
59 }
60
61 }
telsoa014fcda012018-03-09 14:13:49 +000062
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Manage() override
64 {
David Monahan66dbf5b2021-03-11 11:34:54 +000065 // If we have enabled Importing, don't manage the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010066 if (m_IsImportEnabled)
67 {
68 throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
69 }
70 else
David Monahan66dbf5b2021-03-11 11:34:54 +000071 {
72 assert(m_MemoryGroup != nullptr);
73 m_MemoryGroup->manage(&m_Tensor);
74 }
telsoa01c577f2c2018-08-31 09:22:23 +010075 }
telsoa014fcda012018-03-09 14:13:49 +000076
telsoa01c577f2c2018-08-31 09:22:23 +010077 virtual const void* Map(bool blocking = true) const override
78 {
79 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
80 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
81 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010082
telsoa01c577f2c2018-08-31 09:22:23 +010083 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
84
telsoa01c577f2c2018-08-31 09:22:23 +010085 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000086
87 virtual arm_compute::DataType GetDataType() const override
88 {
89 return m_Tensor.info()->data_type();
90 }
91
telsoa01c577f2c2018-08-31 09:22:23 +010092 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
93 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010094 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010095 }
96
97 TensorShape GetStrides() const override
98 {
99 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
100 }
101
102 TensorShape GetShape() const override
103 {
104 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
105 }
David Beck09e2f272018-10-30 11:38:41 +0000106
David Monahan66dbf5b2021-03-11 11:34:54 +0000107 void SetImportFlags(MemorySourceFlags importFlags)
108 {
109 m_ImportFlags = importFlags;
110 }
111
112 MemorySourceFlags GetImportFlags() const override
113 {
114 return m_ImportFlags;
115 }
116
117 void SetImportEnabledFlag(bool importEnabledFlag)
118 {
119 m_IsImportEnabled = importEnabledFlag;
120 }
121
122 virtual bool Import(void* memory, MemorySource source) override
123 {
124 armnn::IgnoreUnused(memory);
125 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
126 {
127 throw MemoryImportException("ClTensorHandle::Incorrect import flag");
128 }
129 m_Imported = false;
130 return false;
131 }
132
Nikhil Raj60ab9762022-01-13 09:34:44 +0000133 virtual bool CanBeImported(void* memory, MemorySource source) override
134 {
135 // This TensorHandle can never import.
136 armnn::IgnoreUnused(memory, source);
137 return false;
138 }
139
telsoa014fcda012018-03-09 14:13:49 +0000140private:
David Beck09e2f272018-10-30 11:38:41 +0000141 // Only used for testing
142 void CopyOutTo(void* memory) const override
143 {
144 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
145 switch(this->GetDataType())
146 {
147 case arm_compute::DataType::F32:
148 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
149 static_cast<float*>(memory));
150 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000151 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000152 case arm_compute::DataType::QASYMM8:
153 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
154 static_cast<uint8_t*>(memory));
155 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100156 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000157 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
158 case arm_compute::DataType::QASYMM8_SIGNED:
159 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
160 static_cast<int8_t*>(memory));
161 break;
David Beck09e2f272018-10-30 11:38:41 +0000162 case arm_compute::DataType::F16:
163 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
164 static_cast<armnn::Half*>(memory));
165 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100166 case arm_compute::DataType::S16:
167 case arm_compute::DataType::QSYMM16:
168 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
169 static_cast<int16_t*>(memory));
170 break;
James Conroy2dc05722019-09-19 17:00:31 +0100171 case arm_compute::DataType::S32:
172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173 static_cast<int32_t*>(memory));
174 break;
David Beck09e2f272018-10-30 11:38:41 +0000175 default:
176 {
177 throw armnn::UnimplementedException();
178 }
179 }
180 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
181 }
182
183 // Only used for testing
184 void CopyInFrom(const void* memory) override
185 {
186 this->Map(true);
187 switch(this->GetDataType())
188 {
189 case arm_compute::DataType::F32:
190 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
191 this->GetTensor());
192 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000193 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000194 case arm_compute::DataType::QASYMM8:
195 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
196 this->GetTensor());
197 break;
198 case arm_compute::DataType::F16:
199 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
200 this->GetTensor());
201 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100202 case arm_compute::DataType::S16:
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100203 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000204 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
205 case arm_compute::DataType::QASYMM8_SIGNED:
206 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
207 this->GetTensor());
208 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100209 case arm_compute::DataType::QSYMM16:
210 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
211 this->GetTensor());
212 break;
James Conroy2dc05722019-09-19 17:00:31 +0100213 case arm_compute::DataType::S32:
214 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
215 this->GetTensor());
216 break;
David Beck09e2f272018-10-30 11:38:41 +0000217 default:
218 {
219 throw armnn::UnimplementedException();
220 }
221 }
222 this->Unmap();
223 }
224
telsoa014fcda012018-03-09 14:13:49 +0000225 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100226 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan66dbf5b2021-03-11 11:34:54 +0000227 MemorySourceFlags m_ImportFlags;
228 bool m_Imported;
229 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000230};
231
232class ClSubTensorHandle : public IClTensorHandle
233{
234public:
telsoa01c577f2c2018-08-31 09:22:23 +0100235 ClSubTensorHandle(IClTensorHandle* parent,
236 const arm_compute::TensorShape& shape,
237 const arm_compute::Coordinates& coords)
238 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000239 {
telsoa01c577f2c2018-08-31 09:22:23 +0100240 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000241 }
242
243 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
244 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000245
telsoa01c577f2c2018-08-31 09:22:23 +0100246 virtual void Allocate() override {}
247 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000248
telsoa01c577f2c2018-08-31 09:22:23 +0100249 virtual const void* Map(bool blocking = true) const override
250 {
251 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
252 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
253 }
254 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
255
telsoa01c577f2c2018-08-31 09:22:23 +0100256 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000257
258 virtual arm_compute::DataType GetDataType() const override
259 {
260 return m_Tensor.info()->data_type();
261 }
262
telsoa01c577f2c2018-08-31 09:22:23 +0100263 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
264
265 TensorShape GetStrides() const override
266 {
267 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
268 }
269
270 TensorShape GetShape() const override
271 {
272 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
273 }
274
telsoa014fcda012018-03-09 14:13:49 +0000275private:
David Beck09e2f272018-10-30 11:38:41 +0000276 // Only used for testing
277 void CopyOutTo(void* memory) const override
278 {
279 const_cast<ClSubTensorHandle*>(this)->Map(true);
280 switch(this->GetDataType())
281 {
282 case arm_compute::DataType::F32:
283 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
284 static_cast<float*>(memory));
285 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000286 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000287 case arm_compute::DataType::QASYMM8:
288 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
289 static_cast<uint8_t*>(memory));
290 break;
291 case arm_compute::DataType::F16:
292 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
293 static_cast<armnn::Half*>(memory));
294 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100295 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000296 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
297 case arm_compute::DataType::QASYMM8_SIGNED:
298 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
299 static_cast<int8_t*>(memory));
300 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100301 case arm_compute::DataType::S16:
302 case arm_compute::DataType::QSYMM16:
303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304 static_cast<int16_t*>(memory));
305 break;
James Conroy2dc05722019-09-19 17:00:31 +0100306 case arm_compute::DataType::S32:
307 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
308 static_cast<int32_t*>(memory));
309 break;
David Beck09e2f272018-10-30 11:38:41 +0000310 default:
311 {
312 throw armnn::UnimplementedException();
313 }
314 }
315 const_cast<ClSubTensorHandle*>(this)->Unmap();
316 }
317
318 // Only used for testing
319 void CopyInFrom(const void* memory) override
320 {
321 this->Map(true);
322 switch(this->GetDataType())
323 {
324 case arm_compute::DataType::F32:
325 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
326 this->GetTensor());
327 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000328 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000329 case arm_compute::DataType::QASYMM8:
330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
331 this->GetTensor());
332 break;
333 case arm_compute::DataType::F16:
334 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
335 this->GetTensor());
336 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100337 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000338 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
339 case arm_compute::DataType::QASYMM8_SIGNED:
340 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
341 this->GetTensor());
342 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100343 case arm_compute::DataType::S16:
344 case arm_compute::DataType::QSYMM16:
345 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
346 this->GetTensor());
347 break;
James Conroy2dc05722019-09-19 17:00:31 +0100348 case arm_compute::DataType::S32:
349 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
350 this->GetTensor());
351 break;
David Beck09e2f272018-10-30 11:38:41 +0000352 default:
353 {
354 throw armnn::UnimplementedException();
355 }
356 }
357 this->Unmap();
358 }
359
telsoa01c577f2c2018-08-31 09:22:23 +0100360 mutable arm_compute::CLSubTensor m_Tensor;
361 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000362};
363
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000364} // namespace armnn