blob: f791ee8fc9429bebad37073b625990c44d6f6e2d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
8#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
David Beck09e2f272018-10-30 11:38:41 +000010#include <Half.hpp>
11
telsoa014fcda012018-03-09 14:13:49 +000012#include <arm_compute/runtime/CL/CLTensor.h>
13#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <arm_compute/runtime/CL/CLMemoryGroup.h>
15#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000016#include <arm_compute/core/TensorShape.h>
17#include <arm_compute/core/Coordinates.h>
18
telsoa01c577f2c2018-08-31 09:22:23 +010019#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21namespace armnn
22{
23
24
25class IClTensorHandle : public ITensorHandle
26{
27public:
28 virtual arm_compute::ICLTensor& GetTensor() = 0;
29 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000030 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010031 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000032};
33
34class ClTensorHandle : public IClTensorHandle
35{
36public:
37 ClTensorHandle(const TensorInfo& tensorInfo)
38 {
39 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40 }
41
Francis Murtagh351d13d2018-09-24 15:01:18 +010042 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
43 {
44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 }
46
telsoa014fcda012018-03-09 14:13:49 +000047 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010049 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000050
telsoa01c577f2c2018-08-31 09:22:23 +010051 virtual void Manage() override
52 {
53 assert(m_MemoryGroup != nullptr);
54 m_MemoryGroup->manage(&m_Tensor);
55 }
telsoa014fcda012018-03-09 14:13:49 +000056
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual const void* Map(bool blocking = true) const override
58 {
59 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
60 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
61 }
62 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
63
telsoa01c577f2c2018-08-31 09:22:23 +010064 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000065
66 virtual arm_compute::DataType GetDataType() const override
67 {
68 return m_Tensor.info()->data_type();
69 }
70
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
72 {
73 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
74 }
75
76 TensorShape GetStrides() const override
77 {
78 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
79 }
80
81 TensorShape GetShape() const override
82 {
83 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
84 }
David Beck09e2f272018-10-30 11:38:41 +000085
telsoa014fcda012018-03-09 14:13:49 +000086private:
David Beck09e2f272018-10-30 11:38:41 +000087 // Only used for testing
88 void CopyOutTo(void* memory) const override
89 {
90 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
91 switch(this->GetDataType())
92 {
93 case arm_compute::DataType::F32:
94 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
95 static_cast<float*>(memory));
96 break;
97 case arm_compute::DataType::QASYMM8:
98 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
99 static_cast<uint8_t*>(memory));
100 break;
101 case arm_compute::DataType::F16:
102 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
103 static_cast<armnn::Half*>(memory));
104 break;
105 default:
106 {
107 throw armnn::UnimplementedException();
108 }
109 }
110 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
111 }
112
113 // Only used for testing
114 void CopyInFrom(const void* memory) override
115 {
116 this->Map(true);
117 switch(this->GetDataType())
118 {
119 case arm_compute::DataType::F32:
120 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
121 this->GetTensor());
122 break;
123 case arm_compute::DataType::QASYMM8:
124 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
125 this->GetTensor());
126 break;
127 case arm_compute::DataType::F16:
128 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
129 this->GetTensor());
130 break;
131 default:
132 {
133 throw armnn::UnimplementedException();
134 }
135 }
136 this->Unmap();
137 }
138
telsoa014fcda012018-03-09 14:13:49 +0000139 arm_compute::CLTensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100140 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000141};
142
143class ClSubTensorHandle : public IClTensorHandle
144{
145public:
telsoa01c577f2c2018-08-31 09:22:23 +0100146 ClSubTensorHandle(IClTensorHandle* parent,
147 const arm_compute::TensorShape& shape,
148 const arm_compute::Coordinates& coords)
149 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000150 {
telsoa01c577f2c2018-08-31 09:22:23 +0100151 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000152 }
153
154 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
155 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000156
telsoa01c577f2c2018-08-31 09:22:23 +0100157 virtual void Allocate() override {}
158 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000159
telsoa01c577f2c2018-08-31 09:22:23 +0100160 virtual const void* Map(bool blocking = true) const override
161 {
162 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
163 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
164 }
165 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
166
telsoa01c577f2c2018-08-31 09:22:23 +0100167 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000168
169 virtual arm_compute::DataType GetDataType() const override
170 {
171 return m_Tensor.info()->data_type();
172 }
173
telsoa01c577f2c2018-08-31 09:22:23 +0100174 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
175
176 TensorShape GetStrides() const override
177 {
178 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
179 }
180
181 TensorShape GetShape() const override
182 {
183 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
184 }
185
telsoa014fcda012018-03-09 14:13:49 +0000186private:
David Beck09e2f272018-10-30 11:38:41 +0000187 // Only used for testing
188 void CopyOutTo(void* memory) const override
189 {
190 const_cast<ClSubTensorHandle*>(this)->Map(true);
191 switch(this->GetDataType())
192 {
193 case arm_compute::DataType::F32:
194 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
195 static_cast<float*>(memory));
196 break;
197 case arm_compute::DataType::QASYMM8:
198 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
199 static_cast<uint8_t*>(memory));
200 break;
201 case arm_compute::DataType::F16:
202 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
203 static_cast<armnn::Half*>(memory));
204 break;
205 default:
206 {
207 throw armnn::UnimplementedException();
208 }
209 }
210 const_cast<ClSubTensorHandle*>(this)->Unmap();
211 }
212
213 // Only used for testing
214 void CopyInFrom(const void* memory) override
215 {
216 this->Map(true);
217 switch(this->GetDataType())
218 {
219 case arm_compute::DataType::F32:
220 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
221 this->GetTensor());
222 break;
223 case arm_compute::DataType::QASYMM8:
224 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
225 this->GetTensor());
226 break;
227 case arm_compute::DataType::F16:
228 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
229 this->GetTensor());
230 break;
231 default:
232 {
233 throw armnn::UnimplementedException();
234 }
235 }
236 this->Unmap();
237 }
238
telsoa01c577f2c2018-08-31 09:22:23 +0100239 mutable arm_compute::CLSubTensor m_Tensor;
240 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000241};
242
David Beck09e2f272018-10-30 11:38:41 +0000243} // namespace armnn