blob: 5fefc125c186ac0dd1ee459c8f47c32576549f46 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
David Beck0dbe0ee2018-09-24 15:59:27 +01006
telsoa014fcda012018-03-09 14:13:49 +00007#include "CpuTensorHandleFwd.hpp"
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +00008#include "CompatibleTypes.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
David Beck0dbe0ee2018-09-24 15:59:27 +010010#include <armnn/TypesUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011
12#include <backendsCommon/OutputHandler.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <algorithm>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnn
17{
18
Matthew Bentham4cefc412019-06-18 16:14:34 +010019// Get a TensorShape representing the strides (in bytes) for each dimension
20// of a tensor, assuming fully packed data with no padding
21TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
22
telsoa01c577f2c2018-08-31 09:22:23 +010023// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
telsoa014fcda012018-03-09 14:13:49 +000024class ConstCpuTensorHandle : public ITensorHandle
25{
26public:
27 template <typename T>
28 const T* GetConstTensor() const
29 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000030 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000031 return reinterpret_cast<const T*>(m_Memory);
32 }
33
34 const TensorInfo& GetTensorInfo() const
35 {
36 return m_TensorInfo;
37 }
38
telsoa01c577f2c2018-08-31 09:22:23 +010039 virtual void Manage() override {}
40
41 virtual ITensorHandle* GetParent() const override { return nullptr; }
42
43 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
44 virtual void Unmap() const override {}
45
46 TensorShape GetStrides() const override
47 {
Matthew Bentham4cefc412019-06-18 16:14:34 +010048 return GetUnpaddedTensorStrides(m_TensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +010049 }
50 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
51
telsoa014fcda012018-03-09 14:13:49 +000052protected:
53 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
54
55 void SetConstMemory(const void* mem) { m_Memory = mem; }
56
57private:
David Beck09e2f272018-10-30 11:38:41 +000058 // Only used for testing
Matthew Bentham4cefc412019-06-18 16:14:34 +010059 void CopyOutTo(void *) const override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
60 void CopyInFrom(const void*) override { BOOST_ASSERT_MSG(false, "Unimplemented"); }
David Beck09e2f272018-10-30 11:38:41 +000061
telsoa014fcda012018-03-09 14:13:49 +000062 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
63 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
64
65 TensorInfo m_TensorInfo;
66 const void* m_Memory;
67};
68
Matteo Martincigh747ef822018-12-18 09:26:39 +000069template<>
70const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
71
telsoa01c577f2c2018-08-31 09:22:23 +010072// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
telsoa014fcda012018-03-09 14:13:49 +000073class CpuTensorHandle : public ConstCpuTensorHandle
74{
75public:
76 template <typename T>
77 T* GetTensor() const
78 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000079 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000080 return reinterpret_cast<T*>(m_MutableMemory);
81 }
82
83protected:
84 CpuTensorHandle(const TensorInfo& tensorInfo);
85
86 void SetMemory(void* mem)
87 {
88 m_MutableMemory = mem;
89 SetConstMemory(m_MutableMemory);
90 }
91
92private:
93
94 CpuTensorHandle(const CpuTensorHandle& other) = delete;
95 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
96 void* m_MutableMemory;
97};
98
Matteo Martincigh747ef822018-12-18 09:26:39 +000099template <>
100void* CpuTensorHandle::GetTensor<void>() const;
101
telsoa014fcda012018-03-09 14:13:49 +0000102// A CpuTensorHandle that owns the wrapped memory region.
103class ScopedCpuTensorHandle : public CpuTensorHandle
104{
105public:
106 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Copies contents from Tensor.
telsoa014fcda012018-03-09 14:13:49 +0000109 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 // Copies contents from ConstCpuTensorHandle
112 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
113
telsoa014fcda012018-03-09 14:13:49 +0000114 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
115 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
116 ~ScopedCpuTensorHandle();
117
118 virtual void Allocate() override;
119
120private:
David Beck09e2f272018-10-30 11:38:41 +0000121 // Only used for testing
122 void CopyOutTo(void* memory) const override;
123 void CopyInFrom(const void* memory) override;
124
telsoa014fcda012018-03-09 14:13:49 +0000125 void CopyFrom(const ScopedCpuTensorHandle& other);
126 void CopyFrom(const void* srcMemory, unsigned int numBytes);
127};
128
129// A CpuTensorHandle that wraps an already allocated memory region.
130//
131// Clients must make sure the passed in memory region stays alive for the lifetime of
132// the PassthroughCpuTensorHandle instance.
133//
telsoa01c577f2c2018-08-31 09:22:23 +0100134// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000135class PassthroughCpuTensorHandle : public CpuTensorHandle
136{
137public:
138 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
139 : CpuTensorHandle(tensorInfo)
140 {
141 SetMemory(mem);
142 }
143
144 virtual void Allocate() override;
145};
146
147// A ConstCpuTensorHandle that wraps an already allocated memory region.
148//
149// This allows users to pass in const memory to a network.
150// Clients must make sure the passed in memory region stays alive for the lifetime of
151// the PassthroughCpuTensorHandle instance.
152//
telsoa01c577f2c2018-08-31 09:22:23 +0100153// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000154class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
155{
156public:
157 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
158 : ConstCpuTensorHandle(tensorInfo)
159 {
160 SetConstMemory(mem);
161 }
162
163 virtual void Allocate() override;
164};
165
166
// The explicit specialisations GetConstTensor<void>() and GetTensor<void>() are declared
// immediately after their respective class definitions above; no redeclaration is needed.
174
175} // namespace armnn