//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "CpuTensorHandleFwd.hpp"
#include <armnn/TypesUtils.hpp>
#include <backends/OutputHandler.hpp>

#include <boost/assert.hpp>

#include <algorithm>
#include <vector>

namespace armnn
{

// Abstract tensor handle wrapping a CPU-readable region of memory, interpreting it as tensor data.
class ConstCpuTensorHandle : public ITensorHandle
{
public:
    template <typename T>
    const T* GetConstTensor() const
    {
        BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
        return reinterpret_cast<const T*>(m_Memory);
    }

    const TensorInfo& GetTensorInfo() const
    {
        return m_TensorInfo;
    }

    virtual ITensorHandle::Type GetType() const override
    {
        return ITensorHandle::Cpu;
    }

    virtual void Manage() override {}

    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
    virtual void Unmap() const override {}

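    // Returns the byte strides of a densely packed, row-major buffer matching this handle's
    // TensorInfo. Worked example: for a Float32 tensor of shape {2, 3, 4} the element size is
    // 4 bytes, so the strides come out as {3*4*4, 4*4, 4} = {48, 16, 4} -- the innermost stride
    // is the element size and each outer stride is the product of all inner dimension sizes
    // and the element size.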
    TensorShape GetStrides() const override
    {
        TensorShape shape(m_TensorInfo.GetShape());
        auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
        auto runningSize = size;
        std::vector<unsigned int> strides(shape.GetNumDimensions());
        auto lastIdx = shape.GetNumDimensions() - 1;
        for (unsigned int i = 0; i < lastIdx; i++)
        {
            strides[lastIdx - i] = runningSize;
            runningSize *= shape[lastIdx - i];
        }
        strides[0] = runningSize;
        return TensorShape(shape.GetNumDimensions(), strides.data());
    }

    TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }

protected:
    ConstCpuTensorHandle(const TensorInfo& tensorInfo);

    void SetConstMemory(const void* mem) { m_Memory = mem; }

private:
    ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
    ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;

    TensorInfo m_TensorInfo;
    const void* m_Memory;
};

// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
class CpuTensorHandle : public ConstCpuTensorHandle
{
public:
    template <typename T>
    T* GetTensor() const
    {
        BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
        return reinterpret_cast<T*>(m_MutableMemory);
    }

protected:
    CpuTensorHandle(const TensorInfo& tensorInfo);

    void SetMemory(void* mem)
    {
        m_MutableMemory = mem;
        SetConstMemory(m_MutableMemory);
    }

private:
    CpuTensorHandle(const CpuTensorHandle& other) = delete;
    CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
    void* m_MutableMemory;
};

// A CpuTensorHandle that owns the wrapped memory region.
class ScopedCpuTensorHandle : public CpuTensorHandle
{
public:
    explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);

    // Copies contents from Tensor.
    explicit ScopedCpuTensorHandle(const ConstTensor& tensor);

    // Copies contents from ConstCpuTensorHandle.
    explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);

    ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
    ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
    ~ScopedCpuTensorHandle();

    virtual void Allocate() override;

private:
    void CopyFrom(const ScopedCpuTensorHandle& other);
    void CopyFrom(const void* srcMemory, unsigned int numBytes);
};
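// Usage sketch (illustrative only; "info" stands for any Float32 TensorInfo, which is not
// defined in this header):
//
//     ScopedCpuTensorHandle handle(info);
//     handle.Allocate();                        // the handle now owns its backing memory
//     float* dst = handle.GetTensor<float>();   // mutable, typed view of that memory
//
// Copy construction and assignment go through the private CopyFrom() helpers, so the owned
// contents are copied rather than shared.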

// A CpuTensorHandle that wraps an already allocated memory region.
//
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughCpuTensorHandle instance.
//
// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
class PassthroughCpuTensorHandle : public CpuTensorHandle
{
public:
    PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
        : CpuTensorHandle(tensorInfo)
    {
        SetMemory(mem);
    }

    virtual void Allocate() override;
};
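// Usage sketch (illustrative only; "info" is a hypothetical Float32 TensorInfo):
//
//     std::vector<float> storage(info.GetNumElements());
//     PassthroughCpuTensorHandle handle(info, storage.data());
//
// The handle merely points at "storage": the caller keeps ownership, must keep the buffer
// alive for the handle's lifetime, and should not rely on Allocate() to provide memory.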

// A ConstCpuTensorHandle that wraps an already allocated memory region.
//
// This allows users to pass in const memory to a network.
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the ConstPassthroughCpuTensorHandle instance.
//
// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
{
public:
    ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
        : ConstCpuTensorHandle(tensorInfo)
    {
        SetConstMemory(mem);
    }

    virtual void Allocate() override;
};
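// Usage sketch (illustrative only; "weightsInfo" is a hypothetical Float32 TensorInfo):
//
//     std::vector<float> weights(weightsInfo.GetNumElements());   // filled by the caller
//     ConstPassthroughCpuTensorHandle handle(weightsInfo, weights.data());
//     const float* w = handle.GetConstTensor<float>();
//
// As with PassthroughCpuTensorHandle, the wrapped buffer is caller-owned and must outlive
// the handle.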


// Template specializations.
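// The void specializations below are presumably defined to return the stored pointer as-is,
// without the GetDataType<T>() assertion the generic templates perform, since no armnn
// DataType corresponds to void.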

template <>
const void* ConstCpuTensorHandle::GetConstTensor() const;

template <>
void* CpuTensorHandle::GetTensor() const;

} // namespace armnn