blob: 59a6bee7f54a9c0c1e10f78cb451799584db7172 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
8#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <arm_compute/runtime/CL/CLTensor.h>
13#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <arm_compute/runtime/CL/CLMemoryGroup.h>
15#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000016#include <arm_compute/core/TensorShape.h>
17#include <arm_compute/core/Coordinates.h>
18
telsoa01c577f2c2018-08-31 09:22:23 +010019#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21namespace armnn
22{
23
24
25class IClTensorHandle : public ITensorHandle
26{
27public:
28 virtual arm_compute::ICLTensor& GetTensor() = 0;
29 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000030 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010031 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000032};
33
34class ClTensorHandle : public IClTensorHandle
35{
36public:
37 ClTensorHandle(const TensorInfo& tensorInfo)
38 {
39 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40 }
41
Francis Murtagh351d13d2018-09-24 15:01:18 +010042 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
43 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010049 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000050
telsoa01c577f2c2018-08-31 09:22:23 +010051 virtual void Manage() override
52 {
53 assert(m_MemoryGroup != nullptr);
54 m_MemoryGroup->manage(&m_Tensor);
55 }
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual const void* Map(bool blocking = true) const override
58 {
59 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
60 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
61 }
62 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
63
telsoa01c577f2c2018-08-31 09:22:23 +010064 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000065
66 virtual arm_compute::DataType GetDataType() const override
67 {
68 return m_Tensor.info()->data_type();
69 }
70
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
72 {
73 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
74 }
75
76 TensorShape GetStrides() const override
77 {
78 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
79 }
80
81 TensorShape GetShape() const override
82 {
83 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
84 }
David Beck09e2f272018-10-30 11:38:41 +000085
telsoa014fcda012018-03-09 14:13:49 +000086private:
David Beck09e2f272018-10-30 11:38:41 +000087 // Only used for testing
88 void CopyOutTo(void* memory) const override
89 {
90 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
91 switch(this->GetDataType())
92 {
93 case arm_compute::DataType::F32:
94 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
95 static_cast<float*>(memory));
96 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +000097 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +000098 case arm_compute::DataType::QASYMM8:
99 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
100 static_cast<uint8_t*>(memory));
101 break;
102 case arm_compute::DataType::F16:
103 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
104 static_cast<armnn::Half*>(memory));
105 break;
106 default:
107 {
108 throw armnn::UnimplementedException();
109 }
110 }
111 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
112 }
113
114 // Only used for testing
115 void CopyInFrom(const void* memory) override
116 {
117 this->Map(true);
118 switch(this->GetDataType())
119 {
120 case arm_compute::DataType::F32:
121 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
122 this->GetTensor());
123 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000124 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000125 case arm_compute::DataType::QASYMM8:
126 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
127 this->GetTensor());
128 break;
129 case arm_compute::DataType::F16:
130 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
131 this->GetTensor());
132 break;
133 default:
134 {
135 throw armnn::UnimplementedException();
136 }
137 }
138 this->Unmap();
139 }
140
telsoa014fcda012018-03-09 14:13:49 +0000141 arm_compute::CLTensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100142 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000143};
144
145class ClSubTensorHandle : public IClTensorHandle
146{
147public:
telsoa01c577f2c2018-08-31 09:22:23 +0100148 ClSubTensorHandle(IClTensorHandle* parent,
149 const arm_compute::TensorShape& shape,
150 const arm_compute::Coordinates& coords)
151 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000152 {
telsoa01c577f2c2018-08-31 09:22:23 +0100153 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000154 }
155
156 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
157 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000158
telsoa01c577f2c2018-08-31 09:22:23 +0100159 virtual void Allocate() override {}
160 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000161
telsoa01c577f2c2018-08-31 09:22:23 +0100162 virtual const void* Map(bool blocking = true) const override
163 {
164 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
165 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
166 }
167 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
168
telsoa01c577f2c2018-08-31 09:22:23 +0100169 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000170
171 virtual arm_compute::DataType GetDataType() const override
172 {
173 return m_Tensor.info()->data_type();
174 }
175
telsoa01c577f2c2018-08-31 09:22:23 +0100176 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
177
178 TensorShape GetStrides() const override
179 {
180 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
181 }
182
183 TensorShape GetShape() const override
184 {
185 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
186 }
187
telsoa014fcda012018-03-09 14:13:49 +0000188private:
David Beck09e2f272018-10-30 11:38:41 +0000189 // Only used for testing
190 void CopyOutTo(void* memory) const override
191 {
192 const_cast<ClSubTensorHandle*>(this)->Map(true);
193 switch(this->GetDataType())
194 {
195 case arm_compute::DataType::F32:
196 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197 static_cast<float*>(memory));
198 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000199 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000200 case arm_compute::DataType::QASYMM8:
201 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
202 static_cast<uint8_t*>(memory));
203 break;
204 case arm_compute::DataType::F16:
205 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
206 static_cast<armnn::Half*>(memory));
207 break;
208 default:
209 {
210 throw armnn::UnimplementedException();
211 }
212 }
213 const_cast<ClSubTensorHandle*>(this)->Unmap();
214 }
215
216 // Only used for testing
217 void CopyInFrom(const void* memory) override
218 {
219 this->Map(true);
220 switch(this->GetDataType())
221 {
222 case arm_compute::DataType::F32:
223 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
224 this->GetTensor());
225 break;
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000226 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000227 case arm_compute::DataType::QASYMM8:
228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
229 this->GetTensor());
230 break;
231 case arm_compute::DataType::F16:
232 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
233 this->GetTensor());
234 break;
235 default:
236 {
237 throw armnn::UnimplementedException();
238 }
239 }
240 this->Unmap();
241 }
242
telsoa01c577f2c2018-08-31 09:22:23 +0100243 mutable arm_compute::CLSubTensor m_Tensor;
244 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000245};
246
Nattapat Chaimanowong8c76cc12019-01-23 09:59:14 +0000247} // namespace armnn