blob: b9720438276d9b062bd9b2ba4020bc071ab55b86 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
8#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include <arm_compute/runtime/MemoryGroup.h>
11#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000012#include <arm_compute/runtime/Tensor.h>
13#include <arm_compute/runtime/SubTensor.h>
14#include <arm_compute/core/TensorShape.h>
15#include <arm_compute/core/Coordinates.h>
16
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19namespace armnn
20{
21
22class INeonTensorHandle : public ITensorHandle
23{
24public:
25 virtual arm_compute::ITensor& GetTensor() = 0;
26 virtual arm_compute::ITensor const& GetTensor() const = 0;
27 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010028 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000029};
30
31class NeonTensorHandle : public INeonTensorHandle
32{
33public:
34 NeonTensorHandle(const TensorInfo& tensorInfo)
35 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
Francis Murtagh351d13d2018-09-24 15:01:18 +010039 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
40 {
41 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
42 }
43
telsoa014fcda012018-03-09 14:13:49 +000044 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
45 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010046
telsoa014fcda012018-03-09 14:13:49 +000047 virtual void Allocate() override
48 {
49 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
50 };
51
telsoa01c577f2c2018-08-31 09:22:23 +010052 virtual void Manage() override
53 {
54 BOOST_ASSERT(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
56 }
57
telsoa01c577f2c2018-08-31 09:22:23 +010058 virtual ITensorHandle* GetParent() const override { return nullptr; }
59
telsoa014fcda012018-03-09 14:13:49 +000060 virtual arm_compute::DataType GetDataType() const override
61 {
62 return m_Tensor.info()->data_type();
63 }
64
telsoa01c577f2c2018-08-31 09:22:23 +010065 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
66 {
67 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
68 }
69
70 virtual const void* Map(bool /* blocking = true */) const override
71 {
72 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
73 }
74 virtual void Unmap() const override {}
75
76
77 TensorShape GetStrides() const override
78 {
79 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
80 }
81
82 TensorShape GetShape() const override
83 {
84 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
85 }
86
telsoa014fcda012018-03-09 14:13:49 +000087private:
David Beck09e2f272018-10-30 11:38:41 +000088 // Only used for testing
89 void CopyOutTo(void* memory) const override
90 {
91 switch (this->GetDataType())
92 {
93 case arm_compute::DataType::F32:
94 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
95 static_cast<float*>(memory));
96 break;
kevmay012b4d88e2019-01-24 14:05:09 +000097 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +000098 case arm_compute::DataType::QASYMM8:
99 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
100 static_cast<uint8_t*>(memory));
101 break;
102 default:
103 {
104 throw armnn::UnimplementedException();
105 }
106 }
107 }
108
109 // Only used for testing
110 void CopyInFrom(const void* memory) override
111 {
112 switch (this->GetDataType())
113 {
114 case arm_compute::DataType::F32:
115 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
116 this->GetTensor());
117 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000118 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000119 case arm_compute::DataType::QASYMM8:
120 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
121 this->GetTensor());
122 break;
123 default:
124 {
125 throw armnn::UnimplementedException();
126 }
127 }
128 }
129
telsoa014fcda012018-03-09 14:13:49 +0000130 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100131 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000132};
133
134class NeonSubTensorHandle : public INeonTensorHandle
135{
136public:
telsoa01c577f2c2018-08-31 09:22:23 +0100137 NeonSubTensorHandle(INeonTensorHandle* parent,
138 const arm_compute::TensorShape& shape,
139 const arm_compute::Coordinates& coords)
140 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000141 {
telsoa01c577f2c2018-08-31 09:22:23 +0100142 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000143 }
144
145 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
146 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100147
148 virtual void Allocate() override {}
149 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000150
telsoa01c577f2c2018-08-31 09:22:23 +0100151 virtual ITensorHandle* GetParent() const override { return parentHandle; }
152
telsoa014fcda012018-03-09 14:13:49 +0000153 virtual arm_compute::DataType GetDataType() const override
154 {
155 return m_Tensor.info()->data_type();
156 }
157
telsoa01c577f2c2018-08-31 09:22:23 +0100158 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
159
160 virtual const void* Map(bool /* blocking = true */) const override
161 {
162 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
163 }
164 virtual void Unmap() const override {}
165
166 TensorShape GetStrides() const override
167 {
168 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
169 }
170
171 TensorShape GetShape() const override
172 {
173 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
174 }
David Beck09e2f272018-10-30 11:38:41 +0000175
telsoa014fcda012018-03-09 14:13:49 +0000176private:
David Beck09e2f272018-10-30 11:38:41 +0000177 // Only used for testing
178 void CopyOutTo(void* memory) const override
179 {
180 switch (this->GetDataType())
181 {
182 case arm_compute::DataType::F32:
183 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
184 static_cast<float*>(memory));
185 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000186 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000187 case arm_compute::DataType::QASYMM8:
188 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
189 static_cast<uint8_t*>(memory));
190 break;
191 default:
192 {
193 throw armnn::UnimplementedException();
194 }
195 }
196 }
197
198 // Only used for testing
199 void CopyInFrom(const void* memory) override
200 {
201 switch (this->GetDataType())
202 {
203 case arm_compute::DataType::F32:
204 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
205 this->GetTensor());
206 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000207 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000208 case arm_compute::DataType::QASYMM8:
209 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
210 this->GetTensor());
211 break;
212 default:
213 {
214 throw armnn::UnimplementedException();
215 }
216 }
217 }
218
telsoa01c577f2c2018-08-31 09:22:23 +0100219 arm_compute::SubTensor m_Tensor;
220 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000221};
222
David Beck09e2f272018-10-30 11:38:41 +0000223} // namespace armnn