blob: 061117e9a687acfdf7fadbf830a9c2aa413e3c04 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Derek Lambertic81855f2019-06-13 17:34:19 +01007#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
Jan Eilers3c9e0452020-04-10 13:00:44 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
13
telsoa014fcda012018-03-09 14:13:49 +000014#include <arm_compute/runtime/CL/CLTensor.h>
15#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010017#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/core/TensorShape.h>
19#include <arm_compute/core/Coordinates.h>
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24
Derek Lambertic81855f2019-06-13 17:34:19 +010025class IClTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000026{
27public:
28 virtual arm_compute::ICLTensor& GetTensor() = 0;
29 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000030 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010031 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000032};
33
34class ClTensorHandle : public IClTensorHandle
35{
36public:
37 ClTensorHandle(const TensorInfo& tensorInfo)
David Monahan66dbf5b2021-03-11 11:34:54 +000038 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
39 m_Imported(false),
40 m_IsImportEnabled(false)
telsoa014fcda012018-03-09 14:13:49 +000041 {
42 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43 }
44
David Monahan66dbf5b2021-03-11 11:34:54 +000045 ClTensorHandle(const TensorInfo& tensorInfo,
46 DataLayout dataLayout,
47 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
48 : m_ImportFlags(importFlags),
49 m_Imported(false),
50 m_IsImportEnabled(false)
Francis Murtagh351d13d2018-09-24 15:01:18 +010051 {
52 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
53 }
54
telsoa014fcda012018-03-09 14:13:49 +000055 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
56 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
David Monahan66dbf5b2021-03-11 11:34:54 +000057 virtual void Allocate() override
58 {
59 // If we have enabled Importing, don't allocate the tensor
60 if (!m_IsImportEnabled)
61 {
62 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
63 }
64
65 }
telsoa014fcda012018-03-09 14:13:49 +000066
telsoa01c577f2c2018-08-31 09:22:23 +010067 virtual void Manage() override
68 {
David Monahan66dbf5b2021-03-11 11:34:54 +000069 // If we have enabled Importing, don't manage the tensor
70 if (!m_IsImportEnabled)
71 {
72 assert(m_MemoryGroup != nullptr);
73 m_MemoryGroup->manage(&m_Tensor);
74 }
telsoa01c577f2c2018-08-31 09:22:23 +010075 }
telsoa014fcda012018-03-09 14:13:49 +000076
telsoa01c577f2c2018-08-31 09:22:23 +010077 virtual const void* Map(bool blocking = true) const override
78 {
79 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
80 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
81 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010082
telsoa01c577f2c2018-08-31 09:22:23 +010083 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
84
telsoa01c577f2c2018-08-31 09:22:23 +010085 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000086
87 virtual arm_compute::DataType GetDataType() const override
88 {
89 return m_Tensor.info()->data_type();
90 }
91
telsoa01c577f2c2018-08-31 09:22:23 +010092 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
93 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010094 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010095 }
96
97 TensorShape GetStrides() const override
98 {
99 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
100 }
101
102 TensorShape GetShape() const override
103 {
104 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
105 }
David Beck09e2f272018-10-30 11:38:41 +0000106
David Monahan66dbf5b2021-03-11 11:34:54 +0000107 void SetImportFlags(MemorySourceFlags importFlags)
108 {
109 m_ImportFlags = importFlags;
110 }
111
112 MemorySourceFlags GetImportFlags() const override
113 {
114 return m_ImportFlags;
115 }
116
117 void SetImportEnabledFlag(bool importEnabledFlag)
118 {
119 m_IsImportEnabled = importEnabledFlag;
120 }
121
122 virtual bool Import(void* memory, MemorySource source) override
123 {
124 armnn::IgnoreUnused(memory);
125 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
126 {
127 throw MemoryImportException("ClTensorHandle::Incorrect import flag");
128 }
129 m_Imported = false;
130 return false;
131 }
132
telsoa014fcda012018-03-09 14:13:49 +0000133private:
David Beck09e2f272018-10-30 11:38:41 +0000134 // Only used for testing
135 void CopyOutTo(void* memory) const override
136 {
137 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
138 switch(this->GetDataType())
139 {
140 case arm_compute::DataType::F32:
141 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
142 static_cast<float*>(memory));
143 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000144 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000145 case arm_compute::DataType::QASYMM8:
146 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
147 static_cast<uint8_t*>(memory));
148 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100149 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000150 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
151 case arm_compute::DataType::QASYMM8_SIGNED:
152 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
153 static_cast<int8_t*>(memory));
154 break;
David Beck09e2f272018-10-30 11:38:41 +0000155 case arm_compute::DataType::F16:
156 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
157 static_cast<armnn::Half*>(memory));
158 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100159 case arm_compute::DataType::S16:
160 case arm_compute::DataType::QSYMM16:
161 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
162 static_cast<int16_t*>(memory));
163 break;
James Conroy2dc05722019-09-19 17:00:31 +0100164 case arm_compute::DataType::S32:
165 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
166 static_cast<int32_t*>(memory));
167 break;
David Beck09e2f272018-10-30 11:38:41 +0000168 default:
169 {
170 throw armnn::UnimplementedException();
171 }
172 }
173 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
174 }
175
176 // Only used for testing
177 void CopyInFrom(const void* memory) override
178 {
179 this->Map(true);
180 switch(this->GetDataType())
181 {
182 case arm_compute::DataType::F32:
183 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
184 this->GetTensor());
185 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000186 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000187 case arm_compute::DataType::QASYMM8:
188 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
189 this->GetTensor());
190 break;
191 case arm_compute::DataType::F16:
192 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
193 this->GetTensor());
194 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100195 case arm_compute::DataType::S16:
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100196 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000197 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
198 case arm_compute::DataType::QASYMM8_SIGNED:
199 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
200 this->GetTensor());
201 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100202 case arm_compute::DataType::QSYMM16:
203 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
204 this->GetTensor());
205 break;
James Conroy2dc05722019-09-19 17:00:31 +0100206 case arm_compute::DataType::S32:
207 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
208 this->GetTensor());
209 break;
David Beck09e2f272018-10-30 11:38:41 +0000210 default:
211 {
212 throw armnn::UnimplementedException();
213 }
214 }
215 this->Unmap();
216 }
217
telsoa014fcda012018-03-09 14:13:49 +0000218 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100219 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
David Monahan66dbf5b2021-03-11 11:34:54 +0000220 MemorySourceFlags m_ImportFlags;
221 bool m_Imported;
222 bool m_IsImportEnabled;
telsoa014fcda012018-03-09 14:13:49 +0000223};
224
225class ClSubTensorHandle : public IClTensorHandle
226{
227public:
telsoa01c577f2c2018-08-31 09:22:23 +0100228 ClSubTensorHandle(IClTensorHandle* parent,
229 const arm_compute::TensorShape& shape,
230 const arm_compute::Coordinates& coords)
231 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000232 {
telsoa01c577f2c2018-08-31 09:22:23 +0100233 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000234 }
235
236 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
237 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000238
telsoa01c577f2c2018-08-31 09:22:23 +0100239 virtual void Allocate() override {}
240 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000241
telsoa01c577f2c2018-08-31 09:22:23 +0100242 virtual const void* Map(bool blocking = true) const override
243 {
244 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
245 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
246 }
247 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
248
telsoa01c577f2c2018-08-31 09:22:23 +0100249 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000250
251 virtual arm_compute::DataType GetDataType() const override
252 {
253 return m_Tensor.info()->data_type();
254 }
255
telsoa01c577f2c2018-08-31 09:22:23 +0100256 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
257
258 TensorShape GetStrides() const override
259 {
260 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
261 }
262
263 TensorShape GetShape() const override
264 {
265 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
266 }
267
telsoa014fcda012018-03-09 14:13:49 +0000268private:
David Beck09e2f272018-10-30 11:38:41 +0000269 // Only used for testing
270 void CopyOutTo(void* memory) const override
271 {
272 const_cast<ClSubTensorHandle*>(this)->Map(true);
273 switch(this->GetDataType())
274 {
275 case arm_compute::DataType::F32:
276 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
277 static_cast<float*>(memory));
278 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000279 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000280 case arm_compute::DataType::QASYMM8:
281 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
282 static_cast<uint8_t*>(memory));
283 break;
284 case arm_compute::DataType::F16:
285 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
286 static_cast<armnn::Half*>(memory));
287 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100288 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000289 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
290 case arm_compute::DataType::QASYMM8_SIGNED:
291 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
292 static_cast<int8_t*>(memory));
293 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100294 case arm_compute::DataType::S16:
295 case arm_compute::DataType::QSYMM16:
296 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
297 static_cast<int16_t*>(memory));
298 break;
James Conroy2dc05722019-09-19 17:00:31 +0100299 case arm_compute::DataType::S32:
300 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
301 static_cast<int32_t*>(memory));
302 break;
David Beck09e2f272018-10-30 11:38:41 +0000303 default:
304 {
305 throw armnn::UnimplementedException();
306 }
307 }
308 const_cast<ClSubTensorHandle*>(this)->Unmap();
309 }
310
311 // Only used for testing
312 void CopyInFrom(const void* memory) override
313 {
314 this->Map(true);
315 switch(this->GetDataType())
316 {
317 case arm_compute::DataType::F32:
318 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
319 this->GetTensor());
320 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000321 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000322 case arm_compute::DataType::QASYMM8:
323 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
324 this->GetTensor());
325 break;
326 case arm_compute::DataType::F16:
327 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
328 this->GetTensor());
329 break;
Sadik Armaganf40d6d42021-04-22 09:12:11 +0100330 case arm_compute::DataType::QSYMM8:
Keith Davisa8565012020-02-14 12:22:40 +0000331 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
332 case arm_compute::DataType::QASYMM8_SIGNED:
333 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
334 this->GetTensor());
335 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100336 case arm_compute::DataType::S16:
337 case arm_compute::DataType::QSYMM16:
338 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
339 this->GetTensor());
340 break;
James Conroy2dc05722019-09-19 17:00:31 +0100341 case arm_compute::DataType::S32:
342 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
343 this->GetTensor());
344 break;
David Beck09e2f272018-10-30 11:38:41 +0000345 default:
346 {
347 throw armnn::UnimplementedException();
348 }
349 }
350 this->Unmap();
351 }
352
telsoa01c577f2c2018-08-31 09:22:23 +0100353 mutable arm_compute::CLSubTensor m_Tensor;
354 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000355};
356
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000357} // namespace armnn