blob: e6e59fcd4f6992db5f6eab92d2ad55a493d8743b [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00005
telsoa014fcda012018-03-09 14:13:49 +00006#pragma once
David Beck0dbe0ee2018-09-24 15:59:27 +01007
Matteo Martincighe5b8eb92019-11-28 15:45:42 +00008#include <armnn/backends/CpuTensorHandleFwd.hpp>
9#include <armnn/backends/ITensorHandle.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010
David Beck0dbe0ee2018-09-24 15:59:27 +010011#include <armnn/TypesUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000013#include <CompatibleTypes.hpp>
telsoa014fcda012018-03-09 14:13:49 +000014
telsoa01c577f2c2018-08-31 09:22:23 +010015#include <algorithm>
16
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000017#include <boost/assert.hpp>
18
telsoa014fcda012018-03-09 14:13:49 +000019namespace armnn
20{
21
Matthew Bentham4cefc412019-06-18 16:14:34 +010022// Get a TensorShape representing the strides (in bytes) for each dimension
23// of a tensor, assuming fully packed data with no padding
24TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
25
telsoa01c577f2c2018-08-31 09:22:23 +010026// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
telsoa014fcda012018-03-09 14:13:49 +000027class ConstCpuTensorHandle : public ITensorHandle
28{
29public:
30 template <typename T>
31 const T* GetConstTensor() const
32 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000033 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000034 return reinterpret_cast<const T*>(m_Memory);
35 }
36
37 const TensorInfo& GetTensorInfo() const
38 {
39 return m_TensorInfo;
40 }
41
telsoa01c577f2c2018-08-31 09:22:23 +010042 virtual void Manage() override {}
43
44 virtual ITensorHandle* GetParent() const override { return nullptr; }
45
46 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
47 virtual void Unmap() const override {}
48
49 TensorShape GetStrides() const override
50 {
Matthew Bentham4cefc412019-06-18 16:14:34 +010051 return GetUnpaddedTensorStrides(m_TensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010052 }
53 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
54
telsoa014fcda012018-03-09 14:13:49 +000055protected:
56 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
57
58 void SetConstMemory(const void* mem) { m_Memory = mem; }
59
60private:
David Beck09e2f272018-10-30 11:38:41 +000061 // Only used for testing
Matthew Bentham4cefc412019-06-18 16:14:34 +010062 void CopyOutTo(void *) const override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
63 void CopyInFrom(const void*) override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
David Beck09e2f272018-10-30 11:38:41 +000064
telsoa014fcda012018-03-09 14:13:49 +000065 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
66 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
67
68 TensorInfo m_TensorInfo;
69 const void* m_Memory;
70};
71
Matteo Martincigh747ef822018-12-18 09:26:39 +000072template<>
73const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
74
telsoa01c577f2c2018-08-31 09:22:23 +010075// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
telsoa014fcda012018-03-09 14:13:49 +000076class CpuTensorHandle : public ConstCpuTensorHandle
77{
78public:
79 template <typename T>
80 T* GetTensor() const
81 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000082 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000083 return reinterpret_cast<T*>(m_MutableMemory);
84 }
85
86protected:
87 CpuTensorHandle(const TensorInfo& tensorInfo);
88
89 void SetMemory(void* mem)
90 {
91 m_MutableMemory = mem;
92 SetConstMemory(m_MutableMemory);
93 }
94
95private:
96
97 CpuTensorHandle(const CpuTensorHandle& other) = delete;
98 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
99 void* m_MutableMemory;
100};
101
Matteo Martincigh747ef822018-12-18 09:26:39 +0000102template <>
103void* CpuTensorHandle::GetTensor<void>() const;
104
telsoa014fcda012018-03-09 14:13:49 +0000105// A CpuTensorHandle that owns the wrapped memory region.
106class ScopedCpuTensorHandle : public CpuTensorHandle
107{
108public:
109 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 // Copies contents from Tensor.
telsoa014fcda012018-03-09 14:13:49 +0000112 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
113
telsoa01c577f2c2018-08-31 09:22:23 +0100114 // Copies contents from ConstCpuTensorHandle
115 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
116
telsoa014fcda012018-03-09 14:13:49 +0000117 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
118 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
119 ~ScopedCpuTensorHandle();
120
121 virtual void Allocate() override;
122
123private:
David Beck09e2f272018-10-30 11:38:41 +0000124 // Only used for testing
125 void CopyOutTo(void* memory) const override;
126 void CopyInFrom(const void* memory) override;
127
telsoa014fcda012018-03-09 14:13:49 +0000128 void CopyFrom(const ScopedCpuTensorHandle& other);
129 void CopyFrom(const void* srcMemory, unsigned int numBytes);
130};
131
132// A CpuTensorHandle that wraps an already allocated memory region.
133//
134// Clients must make sure the passed in memory region stays alive for the lifetime of
135// the PassthroughCpuTensorHandle instance.
136//
telsoa01c577f2c2018-08-31 09:22:23 +0100137// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000138class PassthroughCpuTensorHandle : public CpuTensorHandle
139{
140public:
141 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
142 : CpuTensorHandle(tensorInfo)
143 {
144 SetMemory(mem);
145 }
146
147 virtual void Allocate() override;
148};
149
150// A ConstCpuTensorHandle that wraps an already allocated memory region.
151//
152// This allows users to pass in const memory to a network.
153// Clients must make sure the passed in memory region stays alive for the lifetime of
154// the PassthroughCpuTensorHandle instance.
155//
telsoa01c577f2c2018-08-31 09:22:23 +0100156// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000157class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
158{
159public:
160 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
161 : ConstCpuTensorHandle(tensorInfo)
162 {
163 SetConstMemory(mem);
164 }
165
166 virtual void Allocate() override;
167};
168
169
// Template specializations: the explicit specializations GetConstTensor<void>() and
// GetTensor<void>() are declared immediately after ConstCpuTensorHandle and
// CpuTensorHandle respectively; re-declaring them here would be redundant.
177
178} // namespace armnn