blob: 0302ef5790cffc9efd73c233f2a7b4e64a195b51 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Derek Lambertic81855f2019-06-13 17:34:19 +01007#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
Jan Eilers3c9e0452020-04-10 13:00:44 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
13
telsoa014fcda012018-03-09 14:13:49 +000014#include <arm_compute/runtime/CL/CLTensor.h>
15#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010016#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010017#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000018#include <arm_compute/core/TensorShape.h>
19#include <arm_compute/core/Coordinates.h>
20
telsoa014fcda012018-03-09 14:13:49 +000021namespace armnn
22{
23
24
Derek Lambertic81855f2019-06-13 17:34:19 +010025class IClTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000026{
27public:
28 virtual arm_compute::ICLTensor& GetTensor() = 0;
29 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000030 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010031 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000032};
33
34class ClTensorHandle : public IClTensorHandle
35{
36public:
37 ClTensorHandle(const TensorInfo& tensorInfo)
38 {
39 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40 }
41
Francis Murtagh351d13d2018-09-24 15:01:18 +010042 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
43 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010049 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000050
telsoa01c577f2c2018-08-31 09:22:23 +010051 virtual void Manage() override
52 {
53 assert(m_MemoryGroup != nullptr);
54 m_MemoryGroup->manage(&m_Tensor);
55 }
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual const void* Map(bool blocking = true) const override
58 {
59 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
60 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
61 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010062
telsoa01c577f2c2018-08-31 09:22:23 +010063 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
64
telsoa01c577f2c2018-08-31 09:22:23 +010065 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000066
67 virtual arm_compute::DataType GetDataType() const override
68 {
69 return m_Tensor.info()->data_type();
70 }
71
telsoa01c577f2c2018-08-31 09:22:23 +010072 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
73 {
Jan Eilers3c9e0452020-04-10 13:00:44 +010074 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010075 }
76
77 TensorShape GetStrides() const override
78 {
79 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
80 }
81
82 TensorShape GetShape() const override
83 {
84 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
85 }
David Beck09e2f272018-10-30 11:38:41 +000086
telsoa014fcda012018-03-09 14:13:49 +000087private:
David Beck09e2f272018-10-30 11:38:41 +000088 // Only used for testing
89 void CopyOutTo(void* memory) const override
90 {
91 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
92 switch(this->GetDataType())
93 {
94 case arm_compute::DataType::F32:
95 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
96 static_cast<float*>(memory));
97 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +000098 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +000099 case arm_compute::DataType::QASYMM8:
100 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
101 static_cast<uint8_t*>(memory));
102 break;
Keith Davisa8565012020-02-14 12:22:40 +0000103 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
104 case arm_compute::DataType::QASYMM8_SIGNED:
105 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106 static_cast<int8_t*>(memory));
107 break;
David Beck09e2f272018-10-30 11:38:41 +0000108 case arm_compute::DataType::F16:
109 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
110 static_cast<armnn::Half*>(memory));
111 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100112 case arm_compute::DataType::S16:
113 case arm_compute::DataType::QSYMM16:
114 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
115 static_cast<int16_t*>(memory));
116 break;
James Conroy2dc05722019-09-19 17:00:31 +0100117 case arm_compute::DataType::S32:
118 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
119 static_cast<int32_t*>(memory));
120 break;
David Beck09e2f272018-10-30 11:38:41 +0000121 default:
122 {
123 throw armnn::UnimplementedException();
124 }
125 }
126 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
127 }
128
129 // Only used for testing
130 void CopyInFrom(const void* memory) override
131 {
132 this->Map(true);
133 switch(this->GetDataType())
134 {
135 case arm_compute::DataType::F32:
136 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
137 this->GetTensor());
138 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000139 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000140 case arm_compute::DataType::QASYMM8:
141 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
142 this->GetTensor());
143 break;
144 case arm_compute::DataType::F16:
145 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
146 this->GetTensor());
147 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100148 case arm_compute::DataType::S16:
Keith Davisa8565012020-02-14 12:22:40 +0000149 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
150 case arm_compute::DataType::QASYMM8_SIGNED:
151 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
152 this->GetTensor());
153 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100154 case arm_compute::DataType::QSYMM16:
155 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
156 this->GetTensor());
157 break;
James Conroy2dc05722019-09-19 17:00:31 +0100158 case arm_compute::DataType::S32:
159 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
160 this->GetTensor());
161 break;
David Beck09e2f272018-10-30 11:38:41 +0000162 default:
163 {
164 throw armnn::UnimplementedException();
165 }
166 }
167 this->Unmap();
168 }
169
telsoa014fcda012018-03-09 14:13:49 +0000170 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100171 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000172};
173
174class ClSubTensorHandle : public IClTensorHandle
175{
176public:
telsoa01c577f2c2018-08-31 09:22:23 +0100177 ClSubTensorHandle(IClTensorHandle* parent,
178 const arm_compute::TensorShape& shape,
179 const arm_compute::Coordinates& coords)
180 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000181 {
telsoa01c577f2c2018-08-31 09:22:23 +0100182 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000183 }
184
185 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
186 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000187
telsoa01c577f2c2018-08-31 09:22:23 +0100188 virtual void Allocate() override {}
189 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000190
telsoa01c577f2c2018-08-31 09:22:23 +0100191 virtual const void* Map(bool blocking = true) const override
192 {
193 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
194 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
195 }
196 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
197
telsoa01c577f2c2018-08-31 09:22:23 +0100198 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000199
200 virtual arm_compute::DataType GetDataType() const override
201 {
202 return m_Tensor.info()->data_type();
203 }
204
telsoa01c577f2c2018-08-31 09:22:23 +0100205 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
206
207 TensorShape GetStrides() const override
208 {
209 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
210 }
211
212 TensorShape GetShape() const override
213 {
214 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
215 }
216
telsoa014fcda012018-03-09 14:13:49 +0000217private:
David Beck09e2f272018-10-30 11:38:41 +0000218 // Only used for testing
219 void CopyOutTo(void* memory) const override
220 {
221 const_cast<ClSubTensorHandle*>(this)->Map(true);
222 switch(this->GetDataType())
223 {
224 case arm_compute::DataType::F32:
225 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
226 static_cast<float*>(memory));
227 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000228 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000229 case arm_compute::DataType::QASYMM8:
230 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
231 static_cast<uint8_t*>(memory));
232 break;
233 case arm_compute::DataType::F16:
234 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
235 static_cast<armnn::Half*>(memory));
236 break;
Keith Davisa8565012020-02-14 12:22:40 +0000237 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
238 case arm_compute::DataType::QASYMM8_SIGNED:
239 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
240 static_cast<int8_t*>(memory));
241 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100242 case arm_compute::DataType::S16:
243 case arm_compute::DataType::QSYMM16:
244 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
245 static_cast<int16_t*>(memory));
246 break;
James Conroy2dc05722019-09-19 17:00:31 +0100247 case arm_compute::DataType::S32:
248 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
249 static_cast<int32_t*>(memory));
250 break;
David Beck09e2f272018-10-30 11:38:41 +0000251 default:
252 {
253 throw armnn::UnimplementedException();
254 }
255 }
256 const_cast<ClSubTensorHandle*>(this)->Unmap();
257 }
258
259 // Only used for testing
260 void CopyInFrom(const void* memory) override
261 {
262 this->Map(true);
263 switch(this->GetDataType())
264 {
265 case arm_compute::DataType::F32:
266 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
267 this->GetTensor());
268 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000269 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000270 case arm_compute::DataType::QASYMM8:
271 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
272 this->GetTensor());
273 break;
274 case arm_compute::DataType::F16:
275 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
276 this->GetTensor());
277 break;
Keith Davisa8565012020-02-14 12:22:40 +0000278 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
279 case arm_compute::DataType::QASYMM8_SIGNED:
280 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
281 this->GetTensor());
282 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100283 case arm_compute::DataType::S16:
284 case arm_compute::DataType::QSYMM16:
285 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
286 this->GetTensor());
287 break;
James Conroy2dc05722019-09-19 17:00:31 +0100288 case arm_compute::DataType::S32:
289 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
290 this->GetTensor());
291 break;
David Beck09e2f272018-10-30 11:38:41 +0000292 default:
293 {
294 throw armnn::UnimplementedException();
295 }
296 }
297 this->Unmap();
298 }
299
telsoa01c577f2c2018-08-31 09:22:23 +0100300 mutable arm_compute::CLSubTensor m_Tensor;
301 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000302};
303
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000304} // namespace armnn