blob: d08b79f9a6faa95fda7f612958d2fe7a0d91c3a7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
David Beck09e2f272018-10-30 11:38:41 +000011#include <Half.hpp>
12
telsoa014fcda012018-03-09 14:13:49 +000013#include <arm_compute/runtime/CL/CLTensor.h>
14#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010015#include <arm_compute/runtime/CL/CLMemoryGroup.h>
16#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000017#include <arm_compute/core/TensorShape.h>
18#include <arm_compute/core/Coordinates.h>
19
telsoa01c577f2c2018-08-31 09:22:23 +010020#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000021
22namespace armnn
23{
24
25
Derek Lambertic81855f2019-06-13 17:34:19 +010026class IClTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000027{
28public:
29 virtual arm_compute::ICLTensor& GetTensor() = 0;
30 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000031 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010032 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000033};
34
35class ClTensorHandle : public IClTensorHandle
36{
37public:
38 ClTensorHandle(const TensorInfo& tensorInfo)
39 {
40 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
41 }
42
Francis Murtagh351d13d2018-09-24 15:01:18 +010043 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
44 {
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46 }
47
telsoa014fcda012018-03-09 14:13:49 +000048 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
49 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010050 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000051
telsoa01c577f2c2018-08-31 09:22:23 +010052 virtual void Manage() override
53 {
54 assert(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
56 }
telsoa014fcda012018-03-09 14:13:49 +000057
telsoa01c577f2c2018-08-31 09:22:23 +010058 virtual const void* Map(bool blocking = true) const override
59 {
60 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
61 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
62 }
Matthew Bentham7c1603a2019-06-21 17:22:23 +010063
telsoa01c577f2c2018-08-31 09:22:23 +010064 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
65
telsoa01c577f2c2018-08-31 09:22:23 +010066 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000067
68 virtual arm_compute::DataType GetDataType() const override
69 {
70 return m_Tensor.info()->data_type();
71 }
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
74 {
75 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
76 }
77
78 TensorShape GetStrides() const override
79 {
80 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
81 }
82
83 TensorShape GetShape() const override
84 {
85 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
86 }
David Beck09e2f272018-10-30 11:38:41 +000087
telsoa014fcda012018-03-09 14:13:49 +000088private:
David Beck09e2f272018-10-30 11:38:41 +000089 // Only used for testing
90 void CopyOutTo(void* memory) const override
91 {
92 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
93 switch(this->GetDataType())
94 {
95 case arm_compute::DataType::F32:
96 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
97 static_cast<float*>(memory));
98 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +000099 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000100 case arm_compute::DataType::QASYMM8:
101 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
102 static_cast<uint8_t*>(memory));
103 break;
104 case arm_compute::DataType::F16:
105 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106 static_cast<armnn::Half*>(memory));
107 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100108 case arm_compute::DataType::S16:
109 case arm_compute::DataType::QSYMM16:
110 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
111 static_cast<int16_t*>(memory));
112 break;
David Beck09e2f272018-10-30 11:38:41 +0000113 default:
114 {
115 throw armnn::UnimplementedException();
116 }
117 }
118 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
119 }
120
121 // Only used for testing
122 void CopyInFrom(const void* memory) override
123 {
124 this->Map(true);
125 switch(this->GetDataType())
126 {
127 case arm_compute::DataType::F32:
128 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
129 this->GetTensor());
130 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000131 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000132 case arm_compute::DataType::QASYMM8:
133 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
134 this->GetTensor());
135 break;
136 case arm_compute::DataType::F16:
137 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
138 this->GetTensor());
139 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100140 case arm_compute::DataType::S16:
141 case arm_compute::DataType::QSYMM16:
142 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
143 this->GetTensor());
144 break;
David Beck09e2f272018-10-30 11:38:41 +0000145 default:
146 {
147 throw armnn::UnimplementedException();
148 }
149 }
150 this->Unmap();
151 }
152
telsoa014fcda012018-03-09 14:13:49 +0000153 arm_compute::CLTensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100154 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000155};
156
157class ClSubTensorHandle : public IClTensorHandle
158{
159public:
telsoa01c577f2c2018-08-31 09:22:23 +0100160 ClSubTensorHandle(IClTensorHandle* parent,
161 const arm_compute::TensorShape& shape,
162 const arm_compute::Coordinates& coords)
163 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000164 {
telsoa01c577f2c2018-08-31 09:22:23 +0100165 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000166 }
167
168 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
169 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000170
telsoa01c577f2c2018-08-31 09:22:23 +0100171 virtual void Allocate() override {}
172 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000173
telsoa01c577f2c2018-08-31 09:22:23 +0100174 virtual const void* Map(bool blocking = true) const override
175 {
176 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
177 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
178 }
179 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
180
telsoa01c577f2c2018-08-31 09:22:23 +0100181 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000182
183 virtual arm_compute::DataType GetDataType() const override
184 {
185 return m_Tensor.info()->data_type();
186 }
187
telsoa01c577f2c2018-08-31 09:22:23 +0100188 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
189
190 TensorShape GetStrides() const override
191 {
192 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
193 }
194
195 TensorShape GetShape() const override
196 {
197 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
198 }
199
telsoa014fcda012018-03-09 14:13:49 +0000200private:
David Beck09e2f272018-10-30 11:38:41 +0000201 // Only used for testing
202 void CopyOutTo(void* memory) const override
203 {
204 const_cast<ClSubTensorHandle*>(this)->Map(true);
205 switch(this->GetDataType())
206 {
207 case arm_compute::DataType::F32:
208 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
209 static_cast<float*>(memory));
210 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000211 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000212 case arm_compute::DataType::QASYMM8:
213 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
214 static_cast<uint8_t*>(memory));
215 break;
216 case arm_compute::DataType::F16:
217 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
218 static_cast<armnn::Half*>(memory));
219 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100220 case arm_compute::DataType::S16:
221 case arm_compute::DataType::QSYMM16:
222 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
223 static_cast<int16_t*>(memory));
224 break;
David Beck09e2f272018-10-30 11:38:41 +0000225 default:
226 {
227 throw armnn::UnimplementedException();
228 }
229 }
230 const_cast<ClSubTensorHandle*>(this)->Unmap();
231 }
232
233 // Only used for testing
234 void CopyInFrom(const void* memory) override
235 {
236 this->Map(true);
237 switch(this->GetDataType())
238 {
239 case arm_compute::DataType::F32:
240 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
241 this->GetTensor());
242 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000243 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000244 case arm_compute::DataType::QASYMM8:
245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
246 this->GetTensor());
247 break;
248 case arm_compute::DataType::F16:
249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
250 this->GetTensor());
251 break;
James Conroyd2aa85e2019-07-01 17:12:40 +0100252 case arm_compute::DataType::S16:
253 case arm_compute::DataType::QSYMM16:
254 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
255 this->GetTensor());
256 break;
David Beck09e2f272018-10-30 11:38:41 +0000257 default:
258 {
259 throw armnn::UnimplementedException();
260 }
261 }
262 this->Unmap();
263 }
264
telsoa01c577f2c2018-08-31 09:22:23 +0100265 mutable arm_compute::CLSubTensor m_Tensor;
266 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000267};
268
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000269} // namespace armnn