blob: 655427859b3ee6daebcdf843212cf33aead4e82e [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
David Beck0dbe0ee2018-09-24 15:59:27 +01007#include <backends/OutputHandler.hpp>
David Beck711fa312018-09-24 10:46:38 +01008#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
telsoa01c577f2c2018-08-31 09:22:23 +010010#include <arm_compute/runtime/MemoryGroup.h>
11#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000012#include <arm_compute/runtime/Tensor.h>
13#include <arm_compute/runtime/SubTensor.h>
14#include <arm_compute/core/TensorShape.h>
15#include <arm_compute/core/Coordinates.h>
16
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19namespace armnn
20{
21
22class INeonTensorHandle : public ITensorHandle
23{
24public:
25 virtual arm_compute::ITensor& GetTensor() = 0;
26 virtual arm_compute::ITensor const& GetTensor() const = 0;
27 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010028 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000029};
30
31class NeonTensorHandle : public INeonTensorHandle
32{
33public:
34 NeonTensorHandle(const TensorInfo& tensorInfo)
35 {
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37 }
38
Francis Murtagh351d13d2018-09-24 15:01:18 +010039 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
40 {
41 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
42 }
43
telsoa014fcda012018-03-09 14:13:49 +000044 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
45 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010046
telsoa014fcda012018-03-09 14:13:49 +000047 virtual void Allocate() override
48 {
49 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
50 };
51
telsoa01c577f2c2018-08-31 09:22:23 +010052 virtual void Manage() override
53 {
54 BOOST_ASSERT(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
56 }
57
telsoa014fcda012018-03-09 14:13:49 +000058 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
59
telsoa01c577f2c2018-08-31 09:22:23 +010060 virtual ITensorHandle* GetParent() const override { return nullptr; }
61
telsoa014fcda012018-03-09 14:13:49 +000062 virtual arm_compute::DataType GetDataType() const override
63 {
64 return m_Tensor.info()->data_type();
65 }
66
telsoa01c577f2c2018-08-31 09:22:23 +010067 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
68 {
69 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
70 }
71
72 virtual const void* Map(bool /* blocking = true */) const override
73 {
74 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
75 }
76 virtual void Unmap() const override {}
77
78
79 TensorShape GetStrides() const override
80 {
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82 }
83
84 TensorShape GetShape() const override
85 {
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87 }
88
telsoa014fcda012018-03-09 14:13:49 +000089private:
90 arm_compute::Tensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +010091 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +000092};
93
94class NeonSubTensorHandle : public INeonTensorHandle
95{
96public:
telsoa01c577f2c2018-08-31 09:22:23 +010097 NeonSubTensorHandle(INeonTensorHandle* parent,
98 const arm_compute::TensorShape& shape,
99 const arm_compute::Coordinates& coords)
100 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +0000101 {
telsoa01c577f2c2018-08-31 09:22:23 +0100102 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +0000103 }
104
105 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
106 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +0100107
108 virtual void Allocate() override {}
109 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000110
111 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
112
telsoa01c577f2c2018-08-31 09:22:23 +0100113 virtual ITensorHandle* GetParent() const override { return parentHandle; }
114
telsoa014fcda012018-03-09 14:13:49 +0000115 virtual arm_compute::DataType GetDataType() const override
116 {
117 return m_Tensor.info()->data_type();
118 }
119
telsoa01c577f2c2018-08-31 09:22:23 +0100120 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
121
122 virtual const void* Map(bool /* blocking = true */) const override
123 {
124 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
125 }
126 virtual void Unmap() const override {}
127
128 TensorShape GetStrides() const override
129 {
130 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
131 }
132
133 TensorShape GetShape() const override
134 {
135 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
136 }
telsoa014fcda012018-03-09 14:13:49 +0000137private:
telsoa01c577f2c2018-08-31 09:22:23 +0100138 arm_compute::SubTensor m_Tensor;
139 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000140};
141
142}