blob: 2423a8bbcb0d03965ede153d0d5d034f5a50f425 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
David Beck09e2f272018-10-30 11:38:41 +000011#include <Half.hpp>
12
telsoa014fcda012018-03-09 14:13:49 +000013#include <arm_compute/runtime/CL/CLTensor.h>
14#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010015#include <arm_compute/runtime/IMemoryGroup.h>
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010016#include <arm_compute/runtime/MemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000017#include <arm_compute/core/TensorShape.h>
18#include <arm_compute/core/Coordinates.h>
19
telsoa01c577f2c2018-08-31 09:22:23 +010020#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22namespace armnn
23{
24
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class IClTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 virtual arm_compute::ICLTensor& GetTensor() = 0;
30 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000031 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010032 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000033};
34
35class ClTensorHandle : public IClTensorHandle
36{
37public:
38 ClTensorHandle(const TensorInfo& tensorInfo)
39 {
40 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
41 }
42
Francis Murtagh351d13d2018-09-24 15:01:18 +010043 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
44 {
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46 }
47
telsoa014fcda012018-03-09 14:13:49 +000048 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
49 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010050 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000051
telsoa01c577f2c2018-08-31 09:22:23 +010052 virtual void Manage() override
53 {
54 assert(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
56 }
telsoa014fcda012018-03-09 14:13:49 +000057
telsoa01c577f2c2018-08-31 09:22:23 +010058 virtual const void* Map(bool blocking = true) const override
59 {
60 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
61 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
62 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010063
telsoa01c577f2c2018-08-31 09:22:23 +010064 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
65
telsoa01c577f2c2018-08-31 09:22:23 +010066 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000067
68 virtual arm_compute::DataType GetDataType() const override
69 {
70 return m_Tensor.info()->data_type();
71 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
74 {
Narumol Prangnawarat680f9912019-10-01 11:32:10 +010075 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
telsoa01c577f2c2018-08-31 09:22:23 +010076 }
77
78 TensorShape GetStrides() const override
79 {
80 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
81 }
82
83 TensorShape GetShape() const override
84 {
85 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
86 }
David Beck09e2f272018-10-30 11:38:41 +000087
telsoa014fcda012018-03-09 14:13:49 +000088private:
David Beck09e2f272018-10-30 11:38:41 +000089 // Only used for testing
90 void CopyOutTo(void* memory) const override
91 {
92 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
93 switch(this->GetDataType())
94 {
95 case arm_compute::DataType::F32:
96 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
97 static_cast<float*>(memory));
98 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +000099 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000100 case arm_compute::DataType::QASYMM8:
101 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
102 static_cast<uint8_t*>(memory));
103 break;
104 case arm_compute::DataType::F16:
105 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106 static_cast<armnn::Half*>(memory));
107 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100108 case arm_compute::DataType::S16:
109 case arm_compute::DataType::QSYMM16:
110 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
111 static_cast<int16_t*>(memory));
112 break;
James Conroy2dc05722019-09-19 17:00:31 +0100113 case arm_compute::DataType::S32:
114 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
115 static_cast<int32_t*>(memory));
116 break;
David Beck09e2f272018-10-30 11:38:41 +0000117 default:
118 {
119 throw armnn::UnimplementedException();
120 }
121 }
122 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
123 }
124
125 // Only used for testing
126 void CopyInFrom(const void* memory) override
127 {
128 this->Map(true);
129 switch(this->GetDataType())
130 {
131 case arm_compute::DataType::F32:
132 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
133 this->GetTensor());
134 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000135 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000136 case arm_compute::DataType::QASYMM8:
137 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
138 this->GetTensor());
139 break;
140 case arm_compute::DataType::F16:
141 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
142 this->GetTensor());
143 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100144 case arm_compute::DataType::S16:
145 case arm_compute::DataType::QSYMM16:
146 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
147 this->GetTensor());
148 break;
James Conroy2dc05722019-09-19 17:00:31 +0100149 case arm_compute::DataType::S32:
150 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
151 this->GetTensor());
152 break;
David Beck09e2f272018-10-30 11:38:41 +0000153 default:
154 {
155 throw armnn::UnimplementedException();
156 }
157 }
158 this->Unmap();
159 }
160
telsoa014fcda012018-03-09 14:13:49 +0000161 arm_compute::CLTensor m_Tensor;
Narumol Prangnawarat680f9912019-10-01 11:32:10 +0100162 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000163};
164
165class ClSubTensorHandle : public IClTensorHandle
166{
167public:
telsoa01c577f2c2018-08-31 09:22:23 +0100168 ClSubTensorHandle(IClTensorHandle* parent,
169 const arm_compute::TensorShape& shape,
170 const arm_compute::Coordinates& coords)
171 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000172 {
telsoa01c577f2c2018-08-31 09:22:23 +0100173 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000174 }
175
176 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
177 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000178
telsoa01c577f2c2018-08-31 09:22:23 +0100179 virtual void Allocate() override {}
180 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000181
telsoa01c577f2c2018-08-31 09:22:23 +0100182 virtual const void* Map(bool blocking = true) const override
183 {
184 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
185 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
186 }
187 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
188
telsoa01c577f2c2018-08-31 09:22:23 +0100189 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000190
191 virtual arm_compute::DataType GetDataType() const override
192 {
193 return m_Tensor.info()->data_type();
194 }
195
telsoa01c577f2c2018-08-31 09:22:23 +0100196 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
197
198 TensorShape GetStrides() const override
199 {
200 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
201 }
202
203 TensorShape GetShape() const override
204 {
205 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
206 }
207
telsoa014fcda012018-03-09 14:13:49 +0000208private:
David Beck09e2f272018-10-30 11:38:41 +0000209 // Only used for testing
210 void CopyOutTo(void* memory) const override
211 {
212 const_cast<ClSubTensorHandle*>(this)->Map(true);
213 switch(this->GetDataType())
214 {
215 case arm_compute::DataType::F32:
216 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
217 static_cast<float*>(memory));
218 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000219 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000220 case arm_compute::DataType::QASYMM8:
221 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
222 static_cast<uint8_t*>(memory));
223 break;
224 case arm_compute::DataType::F16:
225 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
226 static_cast<armnn::Half*>(memory));
227 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100228 case arm_compute::DataType::S16:
229 case arm_compute::DataType::QSYMM16:
230 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
231 static_cast<int16_t*>(memory));
232 break;
James Conroy2dc05722019-09-19 17:00:31 +0100233 case arm_compute::DataType::S32:
234 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
235 static_cast<int32_t*>(memory));
236 break;
David Beck09e2f272018-10-30 11:38:41 +0000237 default:
238 {
239 throw armnn::UnimplementedException();
240 }
241 }
242 const_cast<ClSubTensorHandle*>(this)->Unmap();
243 }
244
245 // Only used for testing
246 void CopyInFrom(const void* memory) override
247 {
248 this->Map(true);
249 switch(this->GetDataType())
250 {
251 case arm_compute::DataType::F32:
252 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
253 this->GetTensor());
254 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000255 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000256 case arm_compute::DataType::QASYMM8:
257 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
258 this->GetTensor());
259 break;
260 case arm_compute::DataType::F16:
261 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
262 this->GetTensor());
263 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100264 case arm_compute::DataType::S16:
265 case arm_compute::DataType::QSYMM16:
266 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
267 this->GetTensor());
268 break;
James Conroy2dc05722019-09-19 17:00:31 +0100269 case arm_compute::DataType::S32:
270 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
271 this->GetTensor());
272 break;
David Beck09e2f272018-10-30 11:38:41 +0000273 default:
274 {
275 throw armnn::UnimplementedException();
276 }
277 }
278 this->Unmap();
279 }
280
telsoa01c577f2c2018-08-31 09:22:23 +0100281 mutable arm_compute::CLSubTensor m_Tensor;
282 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000283};
284
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000285} // namespace armnn