blob: 9077f348883fe2dd75564f7ee1ff07f074854e52 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
Derek Lambertic81855f2019-06-13 17:34:19 +01008#include <aclCommon/ArmComputeTensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
telsoa01c577f2c2018-08-31 09:22:23 +010011#include <arm_compute/runtime/MemoryGroup.h>
12#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000013#include <arm_compute/runtime/Tensor.h>
14#include <arm_compute/runtime/SubTensor.h>
15#include <arm_compute/core/TensorShape.h>
16#include <arm_compute/core/Coordinates.h>
17
telsoa01c577f2c2018-08-31 09:22:23 +010018#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000019
20namespace armnn
21{
22
Derek Lambertic81855f2019-06-13 17:34:19 +010023class NeonTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +000024{
25public:
26 NeonTensorHandle(const TensorInfo& tensorInfo)
27 {
28 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
29 }
30
Francis Murtagh351d13d2018-09-24 15:01:18 +010031 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
32 {
33 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
34 }
35
telsoa014fcda012018-03-09 14:13:49 +000036 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
37 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010038
telsoa014fcda012018-03-09 14:13:49 +000039 virtual void Allocate() override
40 {
41 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
42 };
43
telsoa01c577f2c2018-08-31 09:22:23 +010044 virtual void Manage() override
45 {
46 BOOST_ASSERT(m_MemoryGroup != nullptr);
47 m_MemoryGroup->manage(&m_Tensor);
48 }
49
telsoa01c577f2c2018-08-31 09:22:23 +010050 virtual ITensorHandle* GetParent() const override { return nullptr; }
51
telsoa014fcda012018-03-09 14:13:49 +000052 virtual arm_compute::DataType GetDataType() const override
53 {
54 return m_Tensor.info()->data_type();
55 }
56
telsoa01c577f2c2018-08-31 09:22:23 +010057 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
58 {
59 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
60 }
61
62 virtual const void* Map(bool /* blocking = true */) const override
63 {
64 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
65 }
66 virtual void Unmap() const override {}
67
68
69 TensorShape GetStrides() const override
70 {
71 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
72 }
73
74 TensorShape GetShape() const override
75 {
76 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
77 }
78
telsoa014fcda012018-03-09 14:13:49 +000079private:
David Beck09e2f272018-10-30 11:38:41 +000080 // Only used for testing
81 void CopyOutTo(void* memory) const override
82 {
83 switch (this->GetDataType())
84 {
85 case arm_compute::DataType::F32:
86 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
87 static_cast<float*>(memory));
88 break;
kevmay012b4d88e2019-01-24 14:05:09 +000089 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +000090 case arm_compute::DataType::QASYMM8:
91 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
92 static_cast<uint8_t*>(memory));
93 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +010094 case arm_compute::DataType::S16:
95 case arm_compute::DataType::QSYMM16:
96 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
97 static_cast<int16_t*>(memory));
98 break;
David Beck09e2f272018-10-30 11:38:41 +000099 default:
100 {
101 throw armnn::UnimplementedException();
102 }
103 }
104 }
105
106 // Only used for testing
107 void CopyInFrom(const void* memory) override
108 {
109 switch (this->GetDataType())
110 {
111 case arm_compute::DataType::F32:
112 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
113 this->GetTensor());
114 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000115 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000116 case arm_compute::DataType::QASYMM8:
117 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
118 this->GetTensor());
119 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100120 case arm_compute::DataType::S16:
121 case arm_compute::DataType::QSYMM16:
122 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
123 this->GetTensor());
124 break;
David Beck09e2f272018-10-30 11:38:41 +0000125 default:
126 {
127 throw armnn::UnimplementedException();
128 }
129 }
130 }
131
telsoa014fcda012018-03-09 14:13:49 +0000132 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100133 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000134};
135
Derek Lambertic81855f2019-06-13 17:34:19 +0100136class NeonSubTensorHandle : public IAclTensorHandle
telsoa014fcda012018-03-09 14:13:49 +0000137{
138public:
Derek Lambertic81855f2019-06-13 17:34:19 +0100139 NeonSubTensorHandle(IAclTensorHandle* parent,
telsoa01c577f2c2018-08-31 09:22:23 +0100140 const arm_compute::TensorShape& shape,
141 const arm_compute::Coordinates& coords)
142 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000143 {
telsoa01c577f2c2018-08-31 09:22:23 +0100144 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000145 }
146
147 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
148 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100149
150 virtual void Allocate() override {}
151 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000152
telsoa01c577f2c2018-08-31 09:22:23 +0100153 virtual ITensorHandle* GetParent() const override { return parentHandle; }
154
telsoa014fcda012018-03-09 14:13:49 +0000155 virtual arm_compute::DataType GetDataType() const override
156 {
157 return m_Tensor.info()->data_type();
158 }
159
telsoa01c577f2c2018-08-31 09:22:23 +0100160 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
161
162 virtual const void* Map(bool /* blocking = true */) const override
163 {
164 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
165 }
166 virtual void Unmap() const override {}
167
168 TensorShape GetStrides() const override
169 {
170 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
171 }
172
173 TensorShape GetShape() const override
174 {
175 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
176 }
David Beck09e2f272018-10-30 11:38:41 +0000177
telsoa014fcda012018-03-09 14:13:49 +0000178private:
David Beck09e2f272018-10-30 11:38:41 +0000179 // Only used for testing
180 void CopyOutTo(void* memory) const override
181 {
182 switch (this->GetDataType())
183 {
184 case arm_compute::DataType::F32:
185 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
186 static_cast<float*>(memory));
187 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000188 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000189 case arm_compute::DataType::QASYMM8:
190 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
191 static_cast<uint8_t*>(memory));
192 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100193 case arm_compute::DataType::S16:
194 case arm_compute::DataType::QSYMM16:
195 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
196 static_cast<int16_t*>(memory));
197 break;
David Beck09e2f272018-10-30 11:38:41 +0000198 default:
199 {
200 throw armnn::UnimplementedException();
201 }
202 }
203 }
204
205 // Only used for testing
206 void CopyInFrom(const void* memory) override
207 {
208 switch (this->GetDataType())
209 {
210 case arm_compute::DataType::F32:
211 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
212 this->GetTensor());
213 break;
kevmay012b4d88e2019-01-24 14:05:09 +0000214 case arm_compute::DataType::U8:
David Beck09e2f272018-10-30 11:38:41 +0000215 case arm_compute::DataType::QASYMM8:
216 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
217 this->GetTensor());
218 break;
Ellen Norris-Thompson29794572019-06-26 16:40:36 +0100219 case arm_compute::DataType::S16:
220 case arm_compute::DataType::QSYMM16:
221 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
222 this->GetTensor());
223 break;
David Beck09e2f272018-10-30 11:38:41 +0000224 default:
225 {
226 throw armnn::UnimplementedException();
227 }
228 }
229 }
230
telsoa01c577f2c2018-08-31 09:22:23 +0100231 arm_compute::SubTensor m_Tensor;
232 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000233};
234
David Beck09e2f272018-10-30 11:38:41 +0000235} // namespace armnn