blob: 181203481478e92a7505497b05caa03afbf25535 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClImportTensorHandleFactory.hpp"
7#include "ClTensorHandle.hpp"
8
9#include <armnn/utility/NumericCast.hpp>
10#include <armnn/utility/PolymorphicDowncast.hpp>
11
12#include <arm_compute/core/Coordinates.h>
13#include <arm_compute/runtime/CL/CLTensor.h>
14
15namespace armnn
16{
17
18using FactoryId = ITensorHandleFactory::FactoryId;
19
20std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateSubTensorHandle(
21 ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const
22{
23 arm_compute::Coordinates coords;
24 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
25
26 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
27 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
28 {
29 // Arm compute indexes tensor coords in reverse order.
30 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
31 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
32 }
33
34 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
35
36 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
37 // must match the parent shapes
38 if (coords.x() != 0 || coords.y() != 0)
39 {
40 return nullptr;
41 }
42 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
43 {
44 return nullptr;
45 }
46
47 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
48 {
49 return nullptr;
50 }
51
52 return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
53}
54
55std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
56{
57 return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, false);
58}
59
60std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
61 DataLayout dataLayout) const
62{
63 return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, false);
64}
65
66std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
67 const bool IsMemoryManaged) const
68{
69 // If IsMemoryManaged is true then throw an exception.
70 if (IsMemoryManaged)
71 {
72 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
73 }
74 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
75 tensorHandle->SetImportEnabledFlag(true);
76 tensorHandle->SetImportFlags(GetImportFlags());
77 return tensorHandle;
78}
79
80std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
81 DataLayout dataLayout,
82 const bool IsMemoryManaged) const
83{
84 // If IsMemoryManaged is true then throw an exception.
85 if (IsMemoryManaged)
86 {
87 throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
88 }
89 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
90 // If we are not Managing the Memory then we must be importing
91 tensorHandle->SetImportEnabledFlag(true);
92 tensorHandle->SetImportFlags(GetImportFlags());
93 return tensorHandle;
94}
95
96const FactoryId& ClImportTensorHandleFactory::GetIdStatic()
97{
98 static const FactoryId s_Id(ClImportTensorHandleFactoryId());
99 return s_Id;
100}
101
102const FactoryId& ClImportTensorHandleFactory::GetId() const
103{
104 return GetIdStatic();
105}
106
107bool ClImportTensorHandleFactory::SupportsSubTensors() const
108{
109 return true;
110}
111
112MemorySourceFlags ClImportTensorHandleFactory::GetExportFlags() const
113{
114 return m_ExportFlags;
115}
116
117MemorySourceFlags ClImportTensorHandleFactory::GetImportFlags() const
118{
119 return m_ImportFlags;
120}
121
122} // namespace armnn