blob: 4e9d87d6eb37b4b2cba5b55a325119073fab46f9 [file] [log] [blame]
James Conroy1f58f032021-04-27 17:13:27 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/backends/TensorHandleFwd.hpp>
9#include <armnn/backends/ITensorHandle.hpp>
10
11#include <armnn/TypesUtils.hpp>
12
13#include <CompatibleTypes.hpp>
14
15#include <algorithm>
16
17#include <armnn/utility/Assert.hpp>
18
19namespace armnn
20{
21
22// Get a TensorShape representing the strides (in bytes) for each dimension
23// of a tensor, assuming fully packed data with no padding
24TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
25
26// Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data.
27class ConstTensorHandle : public ITensorHandle
28{
29public:
30 template <typename T>
31 const T* GetConstTensor() const
32 {
33 ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
34 return reinterpret_cast<const T*>(m_Memory);
35 }
36
37 const TensorInfo& GetTensorInfo() const
38 {
39 return m_TensorInfo;
40 }
41
42 virtual void Manage() override {}
43
44 virtual ITensorHandle* GetParent() const override { return nullptr; }
45
46 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
47 virtual void Unmap() const override {}
48
49 TensorShape GetStrides() const override
50 {
51 return GetUnpaddedTensorStrides(m_TensorInfo);
52 }
53 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
54
55protected:
56 ConstTensorHandle(const TensorInfo& tensorInfo);
57
58 void SetConstMemory(const void* mem) { m_Memory = mem; }
59
60private:
61 // Only used for testing
62 void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
63 void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
64
65 ConstTensorHandle(const ConstTensorHandle& other) = delete;
66 ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
67
68 TensorInfo m_TensorInfo;
69 const void* m_Memory;
70};
71
72template<>
73const void* ConstTensorHandle::GetConstTensor<void>() const;
74
75// Abstract specialization of ConstTensorHandle that allows write access to the same data.
76class TensorHandle : public ConstTensorHandle
77{
78public:
79 template <typename T>
80 T* GetTensor() const
81 {
82 ARMNN_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
83 return reinterpret_cast<T*>(m_MutableMemory);
84 }
85
86protected:
87 TensorHandle(const TensorInfo& tensorInfo);
88
89 void SetMemory(void* mem)
90 {
91 m_MutableMemory = mem;
92 SetConstMemory(m_MutableMemory);
93 }
94
95private:
96
97 TensorHandle(const TensorHandle& other) = delete;
98 TensorHandle& operator=(const TensorHandle& other) = delete;
99 void* m_MutableMemory;
100};
101
102template <>
103void* TensorHandle::GetTensor<void>() const;
104
105// A TensorHandle that owns the wrapped memory region.
106class ScopedTensorHandle : public TensorHandle
107{
108public:
109 explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
110
111 // Copies contents from Tensor.
112 explicit ScopedTensorHandle(const ConstTensor& tensor);
113
114 // Copies contents from ConstTensorHandle
115 explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
116
117 ScopedTensorHandle(const ScopedTensorHandle& other);
118 ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
119 ~ScopedTensorHandle();
120
121 virtual void Allocate() override;
122
123private:
124 // Only used for testing
125 void CopyOutTo(void* memory) const override;
126 void CopyInFrom(const void* memory) override;
127
128 void CopyFrom(const ScopedTensorHandle& other);
129 void CopyFrom(const void* srcMemory, unsigned int numBytes);
130};
131
132// A TensorHandle that wraps an already allocated memory region.
133//
134// Clients must make sure the passed in memory region stays alive for the lifetime of
135// the PassthroughTensorHandle instance.
136//
137// Note there is no polymorphism to/from ConstPassthroughTensorHandle.
138class PassthroughTensorHandle : public TensorHandle
139{
140public:
141 PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
142 : TensorHandle(tensorInfo)
143 {
144 SetMemory(mem);
145 }
146
147 virtual void Allocate() override;
148};
149
150// A ConstTensorHandle that wraps an already allocated memory region.
151//
152// This allows users to pass in const memory to a network.
153// Clients must make sure the passed in memory region stays alive for the lifetime of
154// the PassthroughTensorHandle instance.
155//
156// Note there is no polymorphism to/from PassthroughTensorHandle.
157class ConstPassthroughTensorHandle : public ConstTensorHandle
158{
159public:
160 ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
161 : ConstTensorHandle(tensorInfo)
162 {
163 SetConstMemory(mem);
164 }
165
166 virtual void Allocate() override;
167};
168
169
170// Template specializations.
171
172template <>
173const void* ConstTensorHandle::GetConstTensor() const;
174
175template <>
176void* TensorHandle::GetTensor() const;
177
178class ManagedConstTensorHandle
179{
180
181public:
182 explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
183 : m_Mapped(false)
184 , m_TensorHandle(std::move(ptr)) {};
185
186 /// RAII Managed resource Unmaps MemoryArea once out of scope
187 const void* Map(bool blocking = true)
188 {
189 if (m_TensorHandle)
190 {
191 auto pRet = m_TensorHandle->Map(blocking);
192 m_Mapped = true;
193 return pRet;
194 }
195 else
196 {
197 throw armnn::Exception("Attempting to Map null TensorHandle");
198 }
199
200 }
201
202 // Delete copy constructor as it's unnecessary
203 ManagedConstTensorHandle(const ConstTensorHandle& other) = delete;
204
205 // Delete copy assignment as it's unnecessary
206 ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
207
208 // Delete move assignment as it's unnecessary
209 ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
210
211 ~ManagedConstTensorHandle()
212 {
213 // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled
214 if (m_TensorHandle)
215 {
216 Unmap();
217 }
218 }
219
220 void Unmap()
221 {
222 // Only unmap if mapped and TensorHandle exists.
223 if (m_Mapped && m_TensorHandle)
224 {
225 m_TensorHandle->Unmap();
226 m_Mapped = false;
227 }
228 }
229
230 const TensorInfo& GetTensorInfo() const
231 {
232 return m_TensorHandle->GetTensorInfo();
233 }
234
235 bool IsMapped() const
236 {
237 return m_Mapped;
238 }
239
240private:
241 bool m_Mapped;
242 std::shared_ptr<ConstTensorHandle> m_TensorHandle;
243};
244
245using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG("ConstCpuTensorHandle is deprecated, "
246 "use ConstTensorHandle instead") = ConstTensorHandle;
247using CpuTensorHandle ARMNN_DEPRECATED_MSG("CpuTensorHandle is deprecated, "
248 "use TensorHandle instead") = TensorHandle;
249using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG("ScopedCpuTensorHandle is deprecated, "
250 "use ScopedTensorHandle instead") = ScopedTensorHandle;
251using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG("PassthroughCpuTensorHandle is deprecated, use "
252 "PassthroughTensorHandle instead") = PassthroughTensorHandle;
253using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG("ConstPassthroughCpuTensorHandle is "
254 "deprecated, use ConstPassthroughTensorHandle "
255 "instead") = ConstPassthroughTensorHandle;
256
257} // namespace armnn