blob: dd6413f2e738b422ef673bd6d24bbdac68d3c3ff [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
David Beck0dbe0ee2018-09-24 15:59:27 +01006
telsoa014fcda012018-03-09 14:13:49 +00007#include "CpuTensorHandleFwd.hpp"
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +00008#include "CompatibleTypes.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009
David Beck0dbe0ee2018-09-24 15:59:27 +010010#include <armnn/TypesUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000011
12#include <backendsCommon/OutputHandler.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <algorithm>
15
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnn
17{
18
telsoa01c577f2c2018-08-31 09:22:23 +010019// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
telsoa014fcda012018-03-09 14:13:49 +000020class ConstCpuTensorHandle : public ITensorHandle
21{
22public:
23 template <typename T>
24 const T* GetConstTensor() const
25 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000026 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000027 return reinterpret_cast<const T*>(m_Memory);
28 }
29
30 const TensorInfo& GetTensorInfo() const
31 {
32 return m_TensorInfo;
33 }
34
telsoa01c577f2c2018-08-31 09:22:23 +010035 virtual void Manage() override {}
36
37 virtual ITensorHandle* GetParent() const override { return nullptr; }
38
39 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
40 virtual void Unmap() const override {}
41
42 TensorShape GetStrides() const override
43 {
44 TensorShape shape(m_TensorInfo.GetShape());
45 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
46 auto runningSize = size;
47 std::vector<unsigned int> strides(shape.GetNumDimensions());
48 auto lastIdx = shape.GetNumDimensions()-1;
49 for (unsigned int i=0; i < lastIdx ; i++)
50 {
51 strides[lastIdx-i] = runningSize;
52 runningSize *= shape[lastIdx-i];
53 }
54 strides[0] = runningSize;
55 return TensorShape(shape.GetNumDimensions(), strides.data());
56 }
57 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
58
telsoa014fcda012018-03-09 14:13:49 +000059protected:
60 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
61
62 void SetConstMemory(const void* mem) { m_Memory = mem; }
63
64private:
David Beck09e2f272018-10-30 11:38:41 +000065 // Only used for testing
66 void CopyOutTo(void *) const override {}
67 void CopyInFrom(const void*) override {}
68
telsoa014fcda012018-03-09 14:13:49 +000069 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
70 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
71
72 TensorInfo m_TensorInfo;
73 const void* m_Memory;
74};
75
Matteo Martincigh747ef822018-12-18 09:26:39 +000076template<>
77const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
78
telsoa01c577f2c2018-08-31 09:22:23 +010079// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
telsoa014fcda012018-03-09 14:13:49 +000080class CpuTensorHandle : public ConstCpuTensorHandle
81{
82public:
83 template <typename T>
84 T* GetTensor() const
85 {
Nattapat Chaimanowong649dd952019-01-22 16:10:44 +000086 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
telsoa014fcda012018-03-09 14:13:49 +000087 return reinterpret_cast<T*>(m_MutableMemory);
88 }
89
90protected:
91 CpuTensorHandle(const TensorInfo& tensorInfo);
92
93 void SetMemory(void* mem)
94 {
95 m_MutableMemory = mem;
96 SetConstMemory(m_MutableMemory);
97 }
98
99private:
100
101 CpuTensorHandle(const CpuTensorHandle& other) = delete;
102 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
103 void* m_MutableMemory;
104};
105
Matteo Martincigh747ef822018-12-18 09:26:39 +0000106template <>
107void* CpuTensorHandle::GetTensor<void>() const;
108
telsoa014fcda012018-03-09 14:13:49 +0000109// A CpuTensorHandle that owns the wrapped memory region.
110class ScopedCpuTensorHandle : public CpuTensorHandle
111{
112public:
113 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
114
telsoa01c577f2c2018-08-31 09:22:23 +0100115 // Copies contents from Tensor.
telsoa014fcda012018-03-09 14:13:49 +0000116 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
117
telsoa01c577f2c2018-08-31 09:22:23 +0100118 // Copies contents from ConstCpuTensorHandle
119 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
120
telsoa014fcda012018-03-09 14:13:49 +0000121 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
122 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
123 ~ScopedCpuTensorHandle();
124
125 virtual void Allocate() override;
126
127private:
David Beck09e2f272018-10-30 11:38:41 +0000128 // Only used for testing
129 void CopyOutTo(void* memory) const override;
130 void CopyInFrom(const void* memory) override;
131
telsoa014fcda012018-03-09 14:13:49 +0000132 void CopyFrom(const ScopedCpuTensorHandle& other);
133 void CopyFrom(const void* srcMemory, unsigned int numBytes);
134};
135
136// A CpuTensorHandle that wraps an already allocated memory region.
137//
138// Clients must make sure the passed in memory region stays alive for the lifetime of
139// the PassthroughCpuTensorHandle instance.
140//
telsoa01c577f2c2018-08-31 09:22:23 +0100141// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000142class PassthroughCpuTensorHandle : public CpuTensorHandle
143{
144public:
145 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
146 : CpuTensorHandle(tensorInfo)
147 {
148 SetMemory(mem);
149 }
150
151 virtual void Allocate() override;
152};
153
154// A ConstCpuTensorHandle that wraps an already allocated memory region.
155//
156// This allows users to pass in const memory to a network.
157// Clients must make sure the passed in memory region stays alive for the lifetime of
158// the PassthroughCpuTensorHandle instance.
159//
telsoa01c577f2c2018-08-31 09:22:23 +0100160// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000161class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
162{
163public:
164 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
165 : ConstCpuTensorHandle(tensorInfo)
166 {
167 SetConstMemory(mem);
168 }
169
170 virtual void Allocate() override;
171};
172
173
telsoa01c577f2c2018-08-31 09:22:23 +0100174// Template specializations.
telsoa014fcda012018-03-09 14:13:49 +0000175
176template <>
177const void* ConstCpuTensorHandle::GetConstTensor() const;
178
179template <>
180void* CpuTensorHandle::GetTensor() const;
181
182} // namespace armnn