//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "TosaRefWorkloadFactory.hpp"

#include <Layer.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <backendsCommon/MakeWorkloadHelper.hpp>
#include <backendsCommon/MemImportWorkload.hpp>

#include "TosaRefBackendId.hpp"
#include "TosaRefTensorHandle.hpp"
#include "TosaRefWorkloadFactory.hpp"
#include "workloads/TosaRefWorkloads.hpp"

#include <algorithm>
#include <memory>
15
16
17namespace armnn
18{
19
20namespace
21{
22static const BackendId s_Id{TosaRefBackendId()};
23}
24template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
25std::unique_ptr<IWorkload> TosaRefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
26 const WorkloadInfo& info) const
27{
28 return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
29 (descriptor, info);
30}
31
32template <DataType ArmnnType>
33bool IsDataType(const WorkloadInfo& info)
34{
35 auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
36 auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
37 if (it != std::end(info.m_InputTensorInfos))
38 {
39 return true;
40 }
41 it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
42 if (it != std::end(info.m_OutputTensorInfos))
43 {
44 return true;
45 }
46 return false;
47}
48
49TosaRefWorkloadFactory::TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager)
50 : m_MemoryManager(memoryManager)
51{
52}
53
54TosaRefWorkloadFactory::TosaRefWorkloadFactory()
55 : m_MemoryManager(new TosaRefMemoryManager())
56{
57}
58
59const BackendId& TosaRefWorkloadFactory::GetBackendId() const
60{
61 return s_Id;
62}
63
64bool TosaRefWorkloadFactory::IsLayerSupported(const Layer& layer,
65 Optional<DataType> dataType,
66 std::string& outReasonIfUnsupported)
67{
68 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
69}
70
71bool TosaRefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
72 Optional<DataType> dataType,
73 std::string& outReasonIfUnsupported,
74 const ModelOptions& modelOptions)
75{
76 return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
77}
78
79std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
80 const bool isMemoryManaged) const
81{
82 if (isMemoryManaged)
83 {
84 return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
85 }
86 else
87 {
88 return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
89 }
90}
91
92std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
93 DataLayout dataLayout,
94 const bool isMemoryManaged) const
95{
96 // For TosaRef it is okay to make the TensorHandle memory managed as it can also store a pointer
97 // to unmanaged memory. This also ensures memory alignment.
98 IgnoreUnused(isMemoryManaged, dataLayout);
99
100 if (isMemoryManaged)
101 {
102 return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
103 }
104 else
105 {
106 return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
107 }
108}
109
110std::unique_ptr<IWorkload> TosaRefWorkloadFactory::CreateWorkload(LayerType type,
111 const QueueDescriptor& descriptor,
112 const WorkloadInfo& info) const
113{
114 switch(type)
115 {
116 case LayerType::PreCompiled:
117 {
118 auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
119 return std::make_unique<TosaRefPreCompiledWorkload>(*precompiledQueueDescriptor, info);
120 }
121 default:
122 return nullptr;
123 }
124}
125
126} // namespace armnn