blob: c1a34d24e5840ae41cca4b03d884837829aca90b [file] [log] [blame]
David Monahan8a570462023-11-22 13:24:25 +00001//
2// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "GpuFsaTensorHandle.hpp"
7#include "GpuFsaTensorHandleFactory.hpp"
8
9namespace armnn
10{
11
12using FactoryId = ITensorHandleFactory::FactoryId;
13
14std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
15 const TensorShape& subTensorShape,
16 const unsigned int* subTensorOrigin) const
17{
18 arm_compute::Coordinates coords;
19 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
20
21 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
22 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
23 {
24 // Arm compute indexes tensor coords in reverse order.
25 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
26 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
27 }
28
29 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
30
31 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
32 // must match the parent shapes
33 if (coords.x() != 0 || coords.y() != 0)
34 {
35 return nullptr;
36 }
37 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
38 {
39 return nullptr;
40 }
41
42 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
43 {
44 return nullptr;
45 }
46
47 return std::make_unique<GpuFsaSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
48}
49
50std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
51{
52 return GpuFsaTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
53}
54
55std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
56 DataLayout dataLayout) const
57{
58 return GpuFsaTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
59}
60
61std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
62 const bool IsMemoryManaged) const
63{
64 std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo);
65 if (!IsMemoryManaged)
66 {
67 ARMNN_LOG(warning) << "GpuFsaTensorHandleFactory only has support for memory managed.";
68 }
69 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
70 return tensorHandle;
71}
72
73std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
74 DataLayout dataLayout,
75 const bool IsMemoryManaged) const
76{
77 std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, dataLayout);
78 if (!IsMemoryManaged)
79 {
80 ARMNN_LOG(warning) << "GpuFsaTensorHandleFactory only has support for memory managed.";
81 }
82 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
83 return tensorHandle;
84}
85
86const FactoryId& GpuFsaTensorHandleFactory::GetIdStatic()
87{
88 static const FactoryId s_Id(GpuFsaTensorHandleFactoryId());
89 return s_Id;
90}
91
92const FactoryId& GpuFsaTensorHandleFactory::GetId() const
93{
94 return GetIdStatic();
95}
96
97bool GpuFsaTensorHandleFactory::SupportsSubTensors() const
98{
99 return true;
100}
101
102MemorySourceFlags GpuFsaTensorHandleFactory::GetExportFlags() const
103{
104 return MemorySourceFlags(MemorySource::Undefined);
105}
106
107MemorySourceFlags GpuFsaTensorHandleFactory::GetImportFlags() const
108{
109 return MemorySourceFlags(MemorySource::Undefined);
110}
111
112} // namespace armnn