blob: d4660d6de3c40d2d4a1b0319ce5307b7aea2115e [file] [log] [blame]
James Conroy1f58f032021-04-27 17:13:27 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#include <armnn/Exceptions.hpp>
6#include <armnn/utility/IgnoreUnused.hpp>
7
8#include <backendsCommon/TensorHandle.hpp>
9
10#include <cstring>
11
12namespace armnn
13{
14
15TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16{
17 TensorShape shape(tensorInfo.GetShape());
18 auto size = GetDataTypeSize(tensorInfo.GetDataType());
19 auto runningSize = size;
20 std::vector<unsigned int> strides(shape.GetNumDimensions());
21 auto lastIdx = shape.GetNumDimensions()-1;
22 for (unsigned int i=0; i < lastIdx ; i++)
23 {
24 strides[lastIdx-i] = runningSize;
25 runningSize *= shape[lastIdx-i];
26 }
27 strides[0] = runningSize;
28 return TensorShape(shape.GetNumDimensions(), strides.data());
29}
30
31ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo)
32: m_TensorInfo(tensorInfo)
33, m_Memory(nullptr)
34{
35}
36
37template <>
38const void* ConstTensorHandle::GetConstTensor<void>() const
39{
40 return m_Memory;
41}
42
43TensorHandle::TensorHandle(const TensorInfo& tensorInfo)
44: ConstTensorHandle(tensorInfo)
45, m_MutableMemory(nullptr)
46{
47}
48
49template <>
50void* TensorHandle::GetTensor<void>() const
51{
52 return m_MutableMemory;
53}
54
55ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo)
56: TensorHandle(tensorInfo)
57{
58}
59
60ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor)
61: ScopedTensorHandle(tensor.GetInfo())
62{
63 CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64}
65
66ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle)
67: ScopedTensorHandle(tensorHandle.GetTensorInfo())
68{
69 CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70}
71
72ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other)
73: TensorHandle(other.GetTensorInfo())
74{
75 CopyFrom(other);
76}
77
78ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other)
79{
80 ::operator delete(GetTensor<void>());
81 SetMemory(nullptr);
82 CopyFrom(other);
83 return *this;
84}
85
86ScopedTensorHandle::~ScopedTensorHandle()
87{
88 ::operator delete(GetTensor<void>());
89}
90
91void ScopedTensorHandle::Allocate()
92{
93 if (GetTensor<void>() == nullptr)
94 {
95 SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96 }
97 else
98 {
99 throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle"
100 "that already has allocated memory.");
101 }
102}
103
104void ScopedTensorHandle::CopyOutTo(void* memory) const
105{
106 memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107}
108
109void ScopedTensorHandle::CopyInFrom(const void* memory)
110{
111 memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112}
113
114void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
115{
116 CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117}
118
119void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120{
121 ARMNN_ASSERT(GetTensor<void>() == nullptr);
122 ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123
124 if (srcMemory)
125 {
126 Allocate();
127 memcpy(GetTensor<void>(), srcMemory, numBytes);
128 }
129}
130
131void PassthroughTensorHandle::Allocate()
132{
133 throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called");
134}
135
136void ConstPassthroughTensorHandle::Allocate()
137{
138 throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called");
139}
140
141} // namespace armnn