blob: 5720d2cf11c672c35c376089dd804034f24b6a99 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Derek Lambertic81855f2019-06-13 17:34:19 +01007#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
Jan Eilers3c9e0452020-04-10 13:00:44 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
13
telsoa014fcda012018-03-09 14:13:49 +000014#include <arm_compute/runtime/CL/CLTensor.h>
15#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010017#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/core/TensorShape.h>
19#include <arm_compute/core/Coordinates.h>
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24
Derek Lambertic81855f2019-06-13 17:34:19 +010025class IClTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000026{
27public:
28 virtual arm_compute::ICLTensor& GetTensor() = 0;
29 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000030 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010031 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000032};
33
34class ClTensorHandle : public IClTensorHandle
35{
36public:
37 ClTensorHandle(const TensorInfo& tensorInfo)
David Monahan66dbf5b2021-03-11 11:34:54 +000038 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
39 m_Imported(false),
40 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000041 {
42 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43 }
44
David Monahan66dbf5b2021-03-11 11:34:54 +000045 ClTensorHandle(const TensorInfo& tensorInfo,
46 DataLayout dataLayout,
47 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
48 : m_ImportFlags(importFlags),
49 m_Imported(false),
50 m_IsImportEnabled(false)
Francis Murtagh351d13d2018-09-24 15:01:18 +010051 {
52 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
53 }
54
telsoa014fcda012018-03-09 14:13:49 +000055 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
56 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
David Monahan66dbf5b2021-03-11 11:34:54 +000057 virtual void Allocate() override
58 {
59 // If we have enabled Importing, don't allocate the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010060 if (m_IsImportEnabled)
61 {
62 throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
63 }
64 else
David Monahan66dbf5b2021-03-11 11:34:54 +000065 {
66 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
67 }
68
69 }
telsoa014fcda012018-03-09 14:13:49 +000070
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual void Manage() override
72 {
David Monahan66dbf5b2021-03-11 11:34:54 +000073 // If we have enabled Importing, don't manage the tensor
David Monahane4a41dc2021-04-14 16:55:36 +010074 if (m_IsImportEnabled)
75 {
76 throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
77 }
78 else
David Monahan66dbf5b2021-03-11 11:34:54 +000079 {
80 assert(m_MemoryGroup != nullptr);
81 m_MemoryGroup->manage(&m_Tensor);
82 }
telsoa01c577f2c2018-08-31 09:22:23 +010083 }
telsoa014fcda012018-03-09 14:13:49 +000084
telsoa01c577f2c2018-08-31 09:22:23 +010085 virtual const void* Map(bool blocking = true) const override
86 {
87 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
88 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
89 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010090
telsoa01c577f2c2018-08-31 09:22:23 +010091 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
92
telsoa01c577f2c2018-08-31 09:22:23 +010093 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000094
95 virtual arm_compute::DataType GetDataType() const override
96 {
97 return m_Tensor.info()->data_type();
98 }
99
telsoa01c577f2c2018-08-31 09:22:23 +0100100 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
101 {
Jan Eilers3c9e0452020-04-10 13:00:44 +0100102 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +0100103 }
104
105 TensorShape GetStrides() const override
106 {
107 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
108 }
109
110 TensorShape GetShape() const override
111 {
112 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
113 }
David Beck09e2f272018-10-30 11:38:41 +0000114
David Monahan66dbf5b2021-03-11 11:34:54 +0000115 void SetImportFlags(MemorySourceFlags importFlags)
116 {
117 m_ImportFlags = importFlags;
118 }
119
120 MemorySourceFlags GetImportFlags() const override
121 {
122 return m_ImportFlags;
123 }
124
125 void SetImportEnabledFlag(bool importEnabledFlag)
126 {
127 m_IsImportEnabled = importEnabledFlag;
128 }
129
130 virtual bool Import(void* memory, MemorySource source) override
131 {
132 armnn::IgnoreUnused(memory);
133 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
134 {
135 throw MemoryImportException("ClTensorHandle::Incorrect import flag");
136 }
137 m_Imported = false;
138 return false;
139 }
140
telsoa014fcda012018-03-09 14:13:49 +0000141private:
David Beck09e2f272018-10-30 11:38:41 +0000142 // Only used for testing
143 void CopyOutTo(void* memory) const override
144 {
145 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
146 switch(this->GetDataType())
147 {
148 case arm_compute::DataType::F32:
149 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
150 static_cast<float*>(memory));
151 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000152 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000153 case arm_compute::DataType::QASYMM8:
154 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
155 static_cast<uint8_t*>(memory));
156 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100157 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000158 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
159 case arm_compute::DataType::QASYMM8_SIGNED:
160 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
161 static_cast<int8_t*>(memory));
162 break;
David Beck09e2f272018-10-30 11:38:41 +0000163 case arm_compute::DataType::F16:
164 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
165 static_cast<armnn::Half*>(memory));
166 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100167 case arm_compute::DataType::S16:
168 case arm_compute::DataType::QSYMM16:
169 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
170 static_cast<int16_t*>(memory));
171 break;
James Conroy2dc05722019-09-19 17:00:31 +0100172 case arm_compute::DataType::S32:
173 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
174 static_cast<int32_t*>(memory));
175 break;
David Beck09e2f272018-10-30 11:38:41 +0000176 default:
177 {
178 throw armnn::UnimplementedException();
179 }
180 }
181 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
182 }
183
184 // Only used for testing
185 void CopyInFrom(const void* memory) override
186 {
187 this->Map(true);
188 switch(this->GetDataType())
189 {
190 case arm_compute::DataType::F32:
191 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
192 this->GetTensor());
193 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000194 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000195 case arm_compute::DataType::QASYMM8:
196 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
197 this->GetTensor());
198 break;
199 case arm_compute::DataType::F16:
200 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
201 this->GetTensor());
202 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100203 case arm_compute::DataType::S16:
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100204 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000205 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
206 case arm_compute::DataType::QASYMM8_SIGNED:
207 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
208 this->GetTensor());
209 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100210 case arm_compute::DataType::QSYMM16:
211 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
212 this->GetTensor());
213 break;
James Conroy2dc05722019-09-19 17:00:31 +0100214 case arm_compute::DataType::S32:
215 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
216 this->GetTensor());
217 break;
David Beck09e2f272018-10-30 11:38:41 +0000218 default:
219 {
220 throw armnn::UnimplementedException();
221 }
222 }
223 this->Unmap();
224 }
225
telsoa014fcda012018-03-09 14:13:49 +0000226 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100227 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan66dbf5b2021-03-11 11:34:54 +0000228 MemorySourceFlags m_ImportFlags;
229 bool m_Imported;
230 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000231};
232
233class ClSubTensorHandle : public IClTensorHandle
234{
235public:
telsoa01c577f2c2018-08-31 09:22:23 +0100236 ClSubTensorHandle(IClTensorHandle* parent,
237 const arm_compute::TensorShape& shape,
238 const arm_compute::Coordinates& coords)
239 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000240 {
telsoa01c577f2c2018-08-31 09:22:23 +0100241 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000242 }
243
244 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
245 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000246
telsoa01c577f2c2018-08-31 09:22:23 +0100247 virtual void Allocate() override {}
248 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000249
telsoa01c577f2c2018-08-31 09:22:23 +0100250 virtual const void* Map(bool blocking = true) const override
251 {
252 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
253 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
254 }
255 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
256
telsoa01c577f2c2018-08-31 09:22:23 +0100257 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000258
259 virtual arm_compute::DataType GetDataType() const override
260 {
261 return m_Tensor.info()->data_type();
262 }
263
telsoa01c577f2c2018-08-31 09:22:23 +0100264 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
265
266 TensorShape GetStrides() const override
267 {
268 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
269 }
270
271 TensorShape GetShape() const override
272 {
273 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
274 }
275
telsoa014fcda012018-03-09 14:13:49 +0000276private:
David Beck09e2f272018-10-30 11:38:41 +0000277 // Only used for testing
278 void CopyOutTo(void* memory) const override
279 {
280 const_cast<ClSubTensorHandle*>(this)->Map(true);
281 switch(this->GetDataType())
282 {
283 case arm_compute::DataType::F32:
284 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
285 static_cast<float*>(memory));
286 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000287 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000288 case arm_compute::DataType::QASYMM8:
289 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
290 static_cast<uint8_t*>(memory));
291 break;
292 case arm_compute::DataType::F16:
293 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
294 static_cast<armnn::Half*>(memory));
295 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100296 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000297 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
298 case arm_compute::DataType::QASYMM8_SIGNED:
299 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
300 static_cast<int8_t*>(memory));
301 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100302 case arm_compute::DataType::S16:
303 case arm_compute::DataType::QSYMM16:
304 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
305 static_cast<int16_t*>(memory));
306 break;
James Conroy2dc05722019-09-19 17:00:31 +0100307 case arm_compute::DataType::S32:
308 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
309 static_cast<int32_t*>(memory));
310 break;
David Beck09e2f272018-10-30 11:38:41 +0000311 default:
312 {
313 throw armnn::UnimplementedException();
314 }
315 }
316 const_cast<ClSubTensorHandle*>(this)->Unmap();
317 }
318
319 // Only used for testing
320 void CopyInFrom(const void* memory) override
321 {
322 this->Map(true);
323 switch(this->GetDataType())
324 {
325 case arm_compute::DataType::F32:
326 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
327 this->GetTensor());
328 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000329 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000330 case arm_compute::DataType::QASYMM8:
331 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
332 this->GetTensor());
333 break;
334 case arm_compute::DataType::F16:
335 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
336 this->GetTensor());
337 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100338 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000339 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
340 case arm_compute::DataType::QASYMM8_SIGNED:
341 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
342 this->GetTensor());
343 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100344 case arm_compute::DataType::S16:
345 case arm_compute::DataType::QSYMM16:
346 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
347 this->GetTensor());
348 break;
James Conroy2dc05722019-09-19 17:00:31 +0100349 case arm_compute::DataType::S32:
350 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
351 this->GetTensor());
352 break;
David Beck09e2f272018-10-30 11:38:41 +0000353 default:
354 {
355 throw armnn::UnimplementedException();
356 }
357 }
358 this->Unmap();
359 }
360
telsoa01c577f2c2018-08-31 09:22:23 +0100361 mutable arm_compute::CLSubTensor m_Tensor;
362 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000363};
364
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000365} // namespace armnn