blob: 7206b6fc5a2f1e7e654cf01fd1928ce4fd0a93d6 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007#include <backendsCommon/OutputHandler.hpp>
8#include <aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include <arm_compute/runtime/MemoryGroup.h>
11#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000012#include <arm_compute/runtime/Tensor.h>
13#include <arm_compute/runtime/SubTensor.h>
14#include <arm_compute/core/TensorShape.h>
15#include <arm_compute/core/Coordinates.h>
16
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19namespace armnn
20{
21
22class INeonTensorHandle : public ITensorHandle
23{
24public:
25 virtual arm_compute::ITensor& GetTensor() = 0;
26 virtual arm_compute::ITensor const& GetTensor() const = 0;
27 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010028 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000029};
30
31class NeonTensorHandle : public INeonTensorHandle
32{
33public:
34 NeonTensorHandle(const TensorInfo& tensorInfo)
35 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
Francis Murtagh351d13d2018-09-24 15:01:18 +010039 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
40 {
41 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
42 }
43
telsoa014fcda012018-03-09 14:13:49 +000044 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
45 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010046
telsoa014fcda012018-03-09 14:13:49 +000047 virtual void Allocate() override
48 {
49 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
50 };
51
telsoa01c577f2c2018-08-31 09:22:23 +010052 virtual void Manage() override
53 {
54 BOOST_ASSERT(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
56 }
57
telsoa01c577f2c2018-08-31 09:22:23 +010058 virtual ITensorHandle* GetParent() const override { return nullptr; }
59
telsoa014fcda012018-03-09 14:13:49 +000060 virtual arm_compute::DataType GetDataType() const override
61 {
62 return m_Tensor.info()->data_type();
63 }
64
telsoa01c577f2c2018-08-31 09:22:23 +010065 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
66 {
67 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
68 }
69
70 virtual const void* Map(bool /* blocking = true */) const override
71 {
72 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
73 }
74 virtual void Unmap() const override {}
75
76
77 TensorShape GetStrides() const override
78 {
79 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
80 }
81
82 TensorShape GetShape() const override
83 {
84 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
85 }
86
telsoa014fcda012018-03-09 14:13:49 +000087private:
David Beck09e2f272018-10-30 11:38:41 +000088 // Only used for testing
89 void CopyOutTo(void* memory) const override
90 {
91 switch (this->GetDataType())
92 {
93 case arm_compute::DataType::F32:
94 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
95 static_cast<float*>(memory));
96 break;
97 case arm_compute::DataType::QASYMM8:
98 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
99 static_cast<uint8_t*>(memory));
100 break;
101 default:
102 {
103 throw armnn::UnimplementedException();
104 }
105 }
106 }
107
108 // Only used for testing
109 void CopyInFrom(const void* memory) override
110 {
111 switch (this->GetDataType())
112 {
113 case arm_compute::DataType::F32:
114 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
115 this->GetTensor());
116 break;
117 case arm_compute::DataType::QASYMM8:
118 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
119 this->GetTensor());
120 break;
121 default:
122 {
123 throw armnn::UnimplementedException();
124 }
125 }
126 }
127
telsoa014fcda012018-03-09 14:13:49 +0000128 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100129 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +0000130};
131
132class NeonSubTensorHandle : public INeonTensorHandle
133{
134public:
telsoa01c577f2c2018-08-31 09:22:23 +0100135 NeonSubTensorHandle(INeonTensorHandle* parent,
136 const arm_compute::TensorShape& shape,
137 const arm_compute::Coordinates& coords)
138 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000139 {
telsoa01c577f2c2018-08-31 09:22:23 +0100140 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000141 }
142
143 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
144 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100145
146 virtual void Allocate() override {}
147 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000148
telsoa01c577f2c2018-08-31 09:22:23 +0100149 virtual ITensorHandle* GetParent() const override { return parentHandle; }
150
telsoa014fcda012018-03-09 14:13:49 +0000151 virtual arm_compute::DataType GetDataType() const override
152 {
153 return m_Tensor.info()->data_type();
154 }
155
telsoa01c577f2c2018-08-31 09:22:23 +0100156 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
157
158 virtual const void* Map(bool /* blocking = true */) const override
159 {
160 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
161 }
162 virtual void Unmap() const override {}
163
164 TensorShape GetStrides() const override
165 {
166 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
167 }
168
169 TensorShape GetShape() const override
170 {
171 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
172 }
David Beck09e2f272018-10-30 11:38:41 +0000173
telsoa014fcda012018-03-09 14:13:49 +0000174private:
David Beck09e2f272018-10-30 11:38:41 +0000175 // Only used for testing
176 void CopyOutTo(void* memory) const override
177 {
178 switch (this->GetDataType())
179 {
180 case arm_compute::DataType::F32:
181 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
182 static_cast<float*>(memory));
183 break;
184 case arm_compute::DataType::QASYMM8:
185 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
186 static_cast<uint8_t*>(memory));
187 break;
188 default:
189 {
190 throw armnn::UnimplementedException();
191 }
192 }
193 }
194
195 // Only used for testing
196 void CopyInFrom(const void* memory) override
197 {
198 switch (this->GetDataType())
199 {
200 case arm_compute::DataType::F32:
201 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
202 this->GetTensor());
203 break;
204 case arm_compute::DataType::QASYMM8:
205 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
206 this->GetTensor());
207 break;
208 default:
209 {
210 throw armnn::UnimplementedException();
211 }
212 }
213 }
214
telsoa01c577f2c2018-08-31 09:22:23 +0100215 arm_compute::SubTensor m_Tensor;
216 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000217};
218
David Beck09e2f272018-10-30 11:38:41 +0000219} // namespace armnn