blob: ae13d6c439c64f010dc6274c336e2f176c135fc7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
David Beck0dbe0ee2018-09-24 15:59:27 +01006
telsoa014fcda012018-03-09 14:13:49 +00007#include "CpuTensorHandleFwd.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008
David Beck0dbe0ee2018-09-24 15:59:27 +01009#include <armnn/TypesUtils.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000010
11#include <backendsCommon/OutputHandler.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
telsoa01c577f2c2018-08-31 09:22:23 +010013#include <algorithm>
14
telsoa014fcda012018-03-09 14:13:49 +000015namespace armnn
16{
17
// Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
telsoa014fcda012018-03-09 14:13:49 +000019class ConstCpuTensorHandle : public ITensorHandle
20{
21public:
22 template <typename T>
23 const T* GetConstTensor() const
24 {
25 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
26 return reinterpret_cast<const T*>(m_Memory);
27 }
28
29 const TensorInfo& GetTensorInfo() const
30 {
31 return m_TensorInfo;
32 }
33
telsoa01c577f2c2018-08-31 09:22:23 +010034 virtual void Manage() override {}
35
36 virtual ITensorHandle* GetParent() const override { return nullptr; }
37
38 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
39 virtual void Unmap() const override {}
40
41 TensorShape GetStrides() const override
42 {
43 TensorShape shape(m_TensorInfo.GetShape());
44 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
45 auto runningSize = size;
46 std::vector<unsigned int> strides(shape.GetNumDimensions());
47 auto lastIdx = shape.GetNumDimensions()-1;
48 for (unsigned int i=0; i < lastIdx ; i++)
49 {
50 strides[lastIdx-i] = runningSize;
51 runningSize *= shape[lastIdx-i];
52 }
53 strides[0] = runningSize;
54 return TensorShape(shape.GetNumDimensions(), strides.data());
55 }
56 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
57
telsoa014fcda012018-03-09 14:13:49 +000058protected:
59 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
60
61 void SetConstMemory(const void* mem) { m_Memory = mem; }
62
63private:
David Beck09e2f272018-10-30 11:38:41 +000064 // Only used for testing
65 void CopyOutTo(void *) const override {}
66 void CopyInFrom(const void*) override {}
67
telsoa014fcda012018-03-09 14:13:49 +000068 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
69 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
70
71 TensorInfo m_TensorInfo;
72 const void* m_Memory;
73};
74
// Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
telsoa014fcda012018-03-09 14:13:49 +000076class CpuTensorHandle : public ConstCpuTensorHandle
77{
78public:
79 template <typename T>
80 T* GetTensor() const
81 {
82 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
83 return reinterpret_cast<T*>(m_MutableMemory);
84 }
85
86protected:
87 CpuTensorHandle(const TensorInfo& tensorInfo);
88
89 void SetMemory(void* mem)
90 {
91 m_MutableMemory = mem;
92 SetConstMemory(m_MutableMemory);
93 }
94
95private:
96
97 CpuTensorHandle(const CpuTensorHandle& other) = delete;
98 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
99 void* m_MutableMemory;
100};
101
// A CpuTensorHandle that owns the wrapped memory region.
103class ScopedCpuTensorHandle : public CpuTensorHandle
104{
105public:
106 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 // Copies contents from Tensor.
telsoa014fcda012018-03-09 14:13:49 +0000109 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 // Copies contents from ConstCpuTensorHandle
112 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
113
telsoa014fcda012018-03-09 14:13:49 +0000114 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
115 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
116 ~ScopedCpuTensorHandle();
117
118 virtual void Allocate() override;
119
120private:
David Beck09e2f272018-10-30 11:38:41 +0000121 // Only used for testing
122 void CopyOutTo(void* memory) const override;
123 void CopyInFrom(const void* memory) override;
124
telsoa014fcda012018-03-09 14:13:49 +0000125 void CopyFrom(const ScopedCpuTensorHandle& other);
126 void CopyFrom(const void* srcMemory, unsigned int numBytes);
127};
128
// A CpuTensorHandle that wraps an already allocated memory region.
//
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the PassthroughCpuTensorHandle instance.
//
// Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000135class PassthroughCpuTensorHandle : public CpuTensorHandle
136{
137public:
138 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
139 : CpuTensorHandle(tensorInfo)
140 {
141 SetMemory(mem);
142 }
143
144 virtual void Allocate() override;
145};
146
// A ConstCpuTensorHandle that wraps an already allocated memory region.
//
// This allows users to pass in const memory to a network.
// Clients must make sure the passed in memory region stays alive for the lifetime of
// the ConstPassthroughCpuTensorHandle instance.
//
// Note there is no polymorphism to/from PassthroughCpuTensorHandle.
telsoa014fcda012018-03-09 14:13:49 +0000154class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
155{
156public:
157 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
158 : ConstCpuTensorHandle(tensorInfo)
159 {
160 SetConstMemory(mem);
161 }
162
163 virtual void Allocate() override;
164};
165
166
// Template specializations.
telsoa014fcda012018-03-09 14:13:49 +0000168
169template <>
170const void* ConstCpuTensorHandle::GetConstTensor() const;
171
172template <>
173void* CpuTensorHandle::GetTensor() const;
174
175} // namespace armnn