blob: e3d7b5b491cd5d111b0db0acb4f32578aed997e0 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
7#include "OutputHandler.hpp"
8#include "ArmComputeTensorUtils.hpp"
9
10#include <arm_compute/runtime/CL/CLTensor.h>
11#include <arm_compute/runtime/CL/CLSubTensor.h>
telsoa01c577f2c2018-08-31 09:22:23 +010012#include <arm_compute/runtime/CL/CLMemoryGroup.h>
13#include <arm_compute/runtime/IMemoryGroup.h>
telsoa014fcda012018-03-09 14:13:49 +000014#include <arm_compute/core/TensorShape.h>
15#include <arm_compute/core/Coordinates.h>
16
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <boost/polymorphic_pointer_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000018
19namespace armnn
20{
21
22
23class IClTensorHandle : public ITensorHandle
24{
25public:
26 virtual arm_compute::ICLTensor& GetTensor() = 0;
27 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
telsoa014fcda012018-03-09 14:13:49 +000028 virtual arm_compute::DataType GetDataType() const = 0;
telsoa01c577f2c2018-08-31 09:22:23 +010029 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
telsoa014fcda012018-03-09 14:13:49 +000030};
31
32class ClTensorHandle : public IClTensorHandle
33{
34public:
35 ClTensorHandle(const TensorInfo& tensorInfo)
36 {
37 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
38 }
39
Francis Murtagh351d13d2018-09-24 15:01:18 +010040 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
41 {
42 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
43 }
44
telsoa014fcda012018-03-09 14:13:49 +000045 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
46 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
telsoa01c577f2c2018-08-31 09:22:23 +010047 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
telsoa014fcda012018-03-09 14:13:49 +000048
telsoa01c577f2c2018-08-31 09:22:23 +010049 virtual void Manage() override
50 {
51 assert(m_MemoryGroup != nullptr);
52 m_MemoryGroup->manage(&m_Tensor);
53 }
telsoa014fcda012018-03-09 14:13:49 +000054
telsoa01c577f2c2018-08-31 09:22:23 +010055 virtual const void* Map(bool blocking = true) const override
56 {
57 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
58 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
59 }
60 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
61
62 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
63
64 virtual ITensorHandle* GetParent() const override { return nullptr; }
telsoa014fcda012018-03-09 14:13:49 +000065
66 virtual arm_compute::DataType GetDataType() const override
67 {
68 return m_Tensor.info()->data_type();
69 }
70
telsoa01c577f2c2018-08-31 09:22:23 +010071 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
72 {
73 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
74 }
75
76 TensorShape GetStrides() const override
77 {
78 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
79 }
80
81 TensorShape GetShape() const override
82 {
83 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
84 }
telsoa014fcda012018-03-09 14:13:49 +000085private:
86 arm_compute::CLTensor m_Tensor;
telsoa01c577f2c2018-08-31 09:22:23 +010087 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
telsoa014fcda012018-03-09 14:13:49 +000088};
89
90class ClSubTensorHandle : public IClTensorHandle
91{
92public:
telsoa01c577f2c2018-08-31 09:22:23 +010093 ClSubTensorHandle(IClTensorHandle* parent,
94 const arm_compute::TensorShape& shape,
95 const arm_compute::Coordinates& coords)
96 : m_Tensor(&parent->GetTensor(), shape, coords)
telsoa014fcda012018-03-09 14:13:49 +000097 {
telsoa01c577f2c2018-08-31 09:22:23 +010098 parentHandle = parent;
telsoa014fcda012018-03-09 14:13:49 +000099 }
100
101 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
102 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
telsoa014fcda012018-03-09 14:13:49 +0000103
telsoa01c577f2c2018-08-31 09:22:23 +0100104 virtual void Allocate() override {}
105 virtual void Manage() override {}
telsoa014fcda012018-03-09 14:13:49 +0000106
telsoa01c577f2c2018-08-31 09:22:23 +0100107 virtual const void* Map(bool blocking = true) const override
108 {
109 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
110 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
111 }
112 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
113
114 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
115
116 virtual ITensorHandle* GetParent() const override { return parentHandle; }
telsoa014fcda012018-03-09 14:13:49 +0000117
118 virtual arm_compute::DataType GetDataType() const override
119 {
120 return m_Tensor.info()->data_type();
121 }
122
telsoa01c577f2c2018-08-31 09:22:23 +0100123 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
124
125 TensorShape GetStrides() const override
126 {
127 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
128 }
129
130 TensorShape GetShape() const override
131 {
132 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
133 }
134
telsoa014fcda012018-03-09 14:13:49 +0000135private:
telsoa01c577f2c2018-08-31 09:22:23 +0100136 mutable arm_compute::CLSubTensor m_Tensor;
137 ITensorHandle* parentHandle = nullptr;
telsoa014fcda012018-03-09 14:13:49 +0000138
139};
140
telsoa01c577f2c2018-08-31 09:22:23 +0100141}