blob: 8296b8315c7c74cde39945ccd9d2092d216f3b28 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NeonTensorHandleFactory.hpp"
7#include "NeonTensorHandle.hpp"
8
9#include <boost/core/ignore_unused.hpp>
10
11namespace armnn
12{
13
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010014using FactoryId = ITensorHandleFactory::FactoryId;
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010015
16std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010017 const TensorShape& subTensorShape,
18 const unsigned int* subTensorOrigin)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010019 const
20{
21 const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
22
23 arm_compute::Coordinates coords;
24 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010025 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010026 {
27 // Arm compute indexes tensor coords in reverse order.
28 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
29 coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex]));
30 }
31
32 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
33 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
34 {
35 return nullptr;
36 }
37
38 return std::make_unique<NeonSubTensorHandle>(
39 boost::polymorphic_downcast<IAclTensorHandle*>(&parent), shape, coords);
40}
41
David Monahan3fb7e102019-08-20 11:25:29 +010042std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
43 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010044{
45 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo);
David Monahan3fb7e102019-08-20 11:25:29 +010046 if (IsMemoryManaged)
47 {
48 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
49 }
50 // If we are not Managing the Memory then we must be importing
51 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
52 tensorHandle->SetImportFlags(m_ImportFlags);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010053
54 return tensorHandle;
55}
56
57std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
David Monahan3fb7e102019-08-20 11:25:29 +010058 DataLayout dataLayout,
59 const bool IsMemoryManaged) const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010060{
61 auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
David Monahan3fb7e102019-08-20 11:25:29 +010062 if (IsMemoryManaged)
63 {
64 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
65 }
66 // If we are not Managing the Memory then we must be importing
67 tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
68 tensorHandle->SetImportFlags(m_ImportFlags);
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010069
70 return tensorHandle;
71}
72
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010073const FactoryId& NeonTensorHandleFactory::GetIdStatic()
74{
75 static const FactoryId s_Id(NeonTensorHandleFactoryId());
76 return s_Id;
77}
78
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010079const FactoryId& NeonTensorHandleFactory::GetId() const
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010080{
Jan Eilerse9f0f0f2019-08-16 10:28:37 +010081 return GetIdStatic();
Narumol Prangnawarat4e3e8182019-08-14 12:25:50 +010082}
83
84bool NeonTensorHandleFactory::SupportsSubTensors() const
85{
86 return true;
87}
88
89MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
90{
91 return m_ExportFlags;
92}
93
94MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
95{
96 return m_ImportFlags;
97}
98
Ferran Balaguerbfeb2712019-08-07 15:14:56 +010099} // namespace armnn