blob: 541beefde626469317610986acb92d8a173a89d2 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6#include "CpuTensorHandleFwd.hpp"
7
8#include "armnn/TypesUtils.hpp"
9
10#include "OutputHandler.hpp"
11
telsoa01c577f2c2018-08-31 09:22:23 +010012#include <algorithm>
13
telsoa014fcda012018-03-09 14:13:49 +000014namespace armnn
15{
16
telsoa01c577f2c2018-08-31 09:22:23 +010017// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
telsoa014fcda012018-03-09 14:13:49 +000018class ConstCpuTensorHandle : public ITensorHandle
19{
20public:
21 template <typename T>
22 const T* GetConstTensor() const
23 {
24 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
25 return reinterpret_cast<const T*>(m_Memory);
26 }
27
28 const TensorInfo& GetTensorInfo() const
29 {
30 return m_TensorInfo;
31 }
32
33 virtual ITensorHandle::Type GetType() const override
34 {
35 return ITensorHandle::Cpu;
36 }
37
telsoa01c577f2c2018-08-31 09:22:23 +010038 virtual void Manage() override {}
39
40 virtual ITensorHandle* GetParent() const override { return nullptr; }
41
42 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
43 virtual void Unmap() const override {}
44
45 TensorShape GetStrides() const override
46 {
47 TensorShape shape(m_TensorInfo.GetShape());
48 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
49 auto runningSize = size;
50 std::vector<unsigned int> strides(shape.GetNumDimensions());
51 auto lastIdx = shape.GetNumDimensions()-1;
52 for (unsigned int i=0; i < lastIdx ; i++)
53 {
54 strides[lastIdx-i] = runningSize;
55 runningSize *= shape[lastIdx-i];
56 }
57 strides[0] = runningSize;
58 return TensorShape(shape.GetNumDimensions(), strides.data());
59 }
60 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
61
telsoa014fcda012018-03-09 14:13:49 +000062protected:
63 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
64
65 void SetConstMemory(const void* mem) { m_Memory = mem; }
66
67private:
68 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
69 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
70
71 TensorInfo m_TensorInfo;
72 const void* m_Memory;
73};
74
telsoa01c577f2c2018-08-31 09:22:23 +010075// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
telsoa014fcda012018-03-09 14:13:49 +000076class CpuTensorHandle : public ConstCpuTensorHandle
77{
78public:
79 template <typename T>
80 T* GetTensor() const
81 {
82 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
83 return reinterpret_cast<T*>(m_MutableMemory);
84 }
85
86protected:
87 CpuTensorHandle(const TensorInfo& tensorInfo);
88
89 void SetMemory(void* mem)
90 {
91 m_MutableMemory = mem;
92 SetConstMemory(m_MutableMemory);
93 }
94
95private:
96
97 CpuTensorHandle(const CpuTensorHandle& other) = delete;
98 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
99 void* m_MutableMemory;
100};
101
102// A CpuTensorHandle that owns the wrapped memory region.
103class ScopedCpuTensorHandle : public CpuTensorHandle
104{
105public:
106 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Copies contents from Tensor.
telsoa014fcda012018-03-09 14:13:49 +0000109 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 // Copies contents from ConstCpuTensorHandle
112 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
113
telsoa014fcda012018-03-09 14:13:49 +0000114 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
115 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
116 ~ScopedCpuTensorHandle();
117
118 virtual void Allocate() override;
119
120private:
121 void CopyFrom(const ScopedCpuTensorHandle& other);
122 void CopyFrom(const void* srcMemory, unsigned int numBytes);
123};
124
125// A CpuTensorHandle that wraps an already allocated memory region.
126//
127// Clients must make sure the passed in memory region stays alive for the lifetime of
128// the PassthroughCpuTensorHandle instance.
129//
telsoa01c577f2c2018-08-31 09:22:23 +0100130// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000131class PassthroughCpuTensorHandle : public CpuTensorHandle
132{
133public:
134 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
135 : CpuTensorHandle(tensorInfo)
136 {
137 SetMemory(mem);
138 }
139
140 virtual void Allocate() override;
141};
142
143// A ConstCpuTensorHandle that wraps an already allocated memory region.
144//
145// This allows users to pass in const memory to a network.
146// Clients must make sure the passed in memory region stays alive for the lifetime of
147// the PassthroughCpuTensorHandle instance.
148//
telsoa01c577f2c2018-08-31 09:22:23 +0100149// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000150class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
151{
152public:
153 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
154 : ConstCpuTensorHandle(tensorInfo)
155 {
156 SetConstMemory(mem);
157 }
158
159 virtual void Allocate() override;
160};
161
162
telsoa01c577f2c2018-08-31 09:22:23 +0100163// Template specializations.
telsoa014fcda012018-03-09 14:13:49 +0000164
165template <>
166const void* ConstCpuTensorHandle::GetConstTensor() const;
167
168template <>
169void* CpuTensorHandle::GetTensor() const;
170
171} // namespace armnn