blob: 3bbba785254840f5a2978453d0922e56211efb25 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
telsoa01c577f2c2018-08-31 09:22:23 +010011#include <arm_compute/runtime/MemoryGroup.h>
12#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000013#include <arm_compute/runtime/Tensor.h>
14#include <arm_compute/runtime/SubTensor.h>
15#include <arm_compute/core/TensorShape.h>
16#include <arm_compute/core/Coordinates.h>
17
telsoa01c577f2c2018-08-31 09:22:23 +010018#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
20namespace armnn
21{
22
Derek Lambertic81855f2019-06-13 17:34:19 +010023class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000024{
25public:
26 NeonTensorHandle(const TensorInfo& tensorInfo)
27 {
28 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
29 }
30
Francis Murtagh351d13d2018-09-24 15:01:18 +010031 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
32 {
33 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
34 }
35
telsoa014fcda012018-03-09 14:13:49 +000036 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
37 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010038
telsoa014fcda012018-03-09 14:13:49 +000039 virtual void Allocate() override
40 {
41 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
42 };
43
telsoa01c577f2c2018-08-31 09:22:23 +010044 virtual void Manage() override
45 {
46 BOOST_ASSERT(m_MemoryGroup != nullptr);
47 m_MemoryGroup->manage(&m_Tensor);
48 }
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 virtual ITensorHandle* GetParent() const override { return nullptr; }
51
telsoa014fcda012018-03-09 14:13:49 +000052 virtual arm_compute::DataType GetDataType() const override
53 {
54 return m_Tensor.info()->data_type();
55 }
56
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
58 {
59 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
60 }
61
62 virtual const void* Map(bool /* blocking = true */) const override
63 {
64 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
65 }
66 virtual void Unmap() const override {}
67
68
69 TensorShape GetStrides() const override
70 {
71 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
72 }
73
74 TensorShape GetShape() const override
75 {
76 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
77 }
78
telsoa014fcda012018-03-09 14:13:49 +000079private:
David Beck09e2f272018-10-30 11:38:41 +000080 // Only used for testing
81 void CopyOutTo(void* memory) const override
82 {
83 switch (this->GetDataType())
84 {
85 case arm_compute::DataType::F32:
86 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
87 static_cast<float*>(memory));
88 break;
kevmay012b4d88e2019-01-24 14:05:09 +000089 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +000090 case arm_compute::DataType::QASYMM8:
91 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
92 static_cast<uint8_t*>(memory));
93 break;
94 default:
95 {
96 throw armnn::UnimplementedException();
97 }
98 }
99 }
100
101 // Only used for testing
102 void CopyInFrom(const void* memory) override
103 {
104 switch (this->GetDataType())
105 {
106 case arm_compute::DataType::F32:
107 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
108 this->GetTensor());
109 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000110 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000111 case arm_compute::DataType::QASYMM8:
112 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
113 this->GetTensor());
114 break;
115 default:
116 {
117 throw armnn::UnimplementedException();
118 }
119 }
120 }
121
telsoa014fcda012018-03-09 14:13:49 +0000122 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100123 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000124};
125
Derek Lambertic81855f2019-06-13 17:34:19 +0100126class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000127{
128public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100129 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100130 const arm_compute::TensorShape& shape,
131 const arm_compute::Coordinates& coords)
132 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000133 {
telsoa01c577f2c2018-08-31 09:22:23 +0100134 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000135 }
136
137 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
138 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100139
140 virtual void Allocate() override {}
141 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000142
telsoa01c577f2c2018-08-31 09:22:23 +0100143 virtual ITensorHandle* GetParent() const override { return parentHandle; }
144
telsoa014fcda012018-03-09 14:13:49 +0000145 virtual arm_compute::DataType GetDataType() const override
146 {
147 return m_Tensor.info()->data_type();
148 }
149
telsoa01c577f2c2018-08-31 09:22:23 +0100150 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
151
152 virtual const void* Map(bool /* blocking = true */) const override
153 {
154 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
155 }
156 virtual void Unmap() const override {}
157
158 TensorShape GetStrides() const override
159 {
160 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
161 }
162
163 TensorShape GetShape() const override
164 {
165 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
166 }
David Beck09e2f272018-10-30 11:38:41 +0000167
telsoa014fcda012018-03-09 14:13:49 +0000168private:
David Beck09e2f272018-10-30 11:38:41 +0000169 // Only used for testing
170 void CopyOutTo(void* memory) const override
171 {
172 switch (this->GetDataType())
173 {
174 case arm_compute::DataType::F32:
175 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
176 static_cast<float*>(memory));
177 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000178 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000179 case arm_compute::DataType::QASYMM8:
180 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
181 static_cast<uint8_t*>(memory));
182 break;
183 default:
184 {
185 throw armnn::UnimplementedException();
186 }
187 }
188 }
189
190 // Only used for testing
191 void CopyInFrom(const void* memory) override
192 {
193 switch (this->GetDataType())
194 {
195 case arm_compute::DataType::F32:
196 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
197 this->GetTensor());
198 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000199 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000200 case arm_compute::DataType::QASYMM8:
201 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
202 this->GetTensor());
203 break;
204 default:
205 {
206 throw armnn::UnimplementedException();
207 }
208 }
209 }
210
telsoa01c577f2c2018-08-31 09:22:23 +0100211 arm_compute::SubTensor m_Tensor;
212 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000213};
214
David Beck09e2f272018-10-30 11:38:41 +0000215} // namespace armnn