//
// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <Layer.hpp>

#include "GpuFsaWorkloadFactory.hpp"
#include "GpuFsaBackendId.hpp"
#include "GpuFsaTensorHandle.hpp"

#include "workloads/GpuFsaConstantWorkload.hpp"
#include "workloads/GpuFsaPreCompiledWorkload.hpp"

#include <armnn/backends/MemCopyWorkload.hpp>

namespace armnn
{

namespace
{
static const BackendId s_Id{GpuFsaBackendId()};
}
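
// Fallback factory method. The GpuFsa backend wires up its workloads in
// CreateWorkload() below, so this templated helper simply returns nullptr
// for any descriptor type it is asked about.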
template <typename QueueDescriptorType>
std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::MakeWorkload(const QueueDescriptorType& /*descriptor*/,
                                                               const WorkloadInfo& /*info*/) const
{
    return nullptr;
}

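// Returns true if any input or output tensor described by 'info' has the
// data type given as the template argument.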
template <DataType ArmnnType>
bool IsDataType(const WorkloadInfo& info)
{
    auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
    auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
    if (it != std::end(info.m_InputTensorInfos))
    {
        return true;
    }
    it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
    if (it != std::end(info.m_OutputTensorInfos))
    {
        return true;
    }
    return false;
}

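// Both constructors prime the CL compile context so that workloads created
// by this factory can compile their OpenCL kernels against the default device.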
GpuFsaWorkloadFactory::GpuFsaWorkloadFactory(const std::shared_ptr<GpuFsaMemoryManager>& memoryManager)
    : m_MemoryManager(memoryManager)
{
    InitializeCLCompileContext();
}

GpuFsaWorkloadFactory::GpuFsaWorkloadFactory()
    : m_MemoryManager(new GpuFsaMemoryManager())
{
    InitializeCLCompileContext();
}

const BackendId& GpuFsaWorkloadFactory::GetBackendId() const
{
    return s_Id;
}

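// Forwards the support query to the common IWorkloadFactory logic, keyed on
// this backend's id.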
bool GpuFsaWorkloadFactory::IsLayerSupported(const Layer& layer,
                                             Optional<DataType> dataType,
                                             std::string& outReasonIfUnsupported)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}

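// Tensor handles are backed by the GpuFsa memory manager: each handle is
// registered with the inter-layer memory group, and the isMemoryManaged
// flag is currently ignored.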
std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                         const bool /*isMemoryManaged*/) const
{
    std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo);
    tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());

    return tensorHandle;
}

std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                         DataLayout dataLayout,
                                                                         const bool /*isMemoryManaged*/) const
{
    std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, dataLayout);
    tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());

    return tensorHandle;
}

void GpuFsaWorkloadFactory::InitializeCLCompileContext()
{
    // Initialize our m_CLCompileContext using default device and context
    auto context = arm_compute::CLKernelLibrary::get().context();
    auto device = arm_compute::CLKernelLibrary::get().get_device();
    m_CLCompileContext = arm_compute::CLCompileContext(context, device);
}

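// Creates the workload for a given layer type. Only the layer types listed
// below are handled; everything else returns nullptr, since the bulk of the
// graph is expected to reach this backend as a fused PreCompiled layer.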
std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType type,
                                                                 const QueueDescriptor& descriptor,
                                                                 const WorkloadInfo& info) const
{
    switch(type)
    {
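        // Constant layers get a dedicated workload that needs the CL compile
        // context to build its kernels.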
        case LayerType::Constant :
        {
            auto constQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor);
            return std::make_unique<GpuFsaConstantWorkload>(*constQueueDescriptor, info, m_CLCompileContext);
        }
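        // Input and Output layers are serviced by generic memory copies
        // between the caller's tensors and the backend's tensor handles.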
        case LayerType::Input :
        {
            auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor);
            return std::make_unique<CopyMemGenericWorkload>(*inputQueueDescriptor, info);
        }
        case LayerType::Output :
        {
            auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor);
            return std::make_unique<CopyMemGenericWorkload>(*outputQueueDescriptor, info);
        }
        case LayerType::MemCopy :
        {
            auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
            if (memCopyQueueDescriptor->m_Inputs.empty() || !memCopyQueueDescriptor->m_Inputs[0])
            {
                throw InvalidArgumentException("GpuFsaWorkloadFactory: Invalid null input for MemCopy workload");
            }
            return std::make_unique<CopyMemGenericWorkload>(*memCopyQueueDescriptor, info);
        }
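        // PreCompiled layers carry the fused graph produced at optimization
        // time; the workload executes that graph as a unit.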
        case LayerType::PreCompiled :
        {
            auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
            return std::make_unique<GpuFsaPreCompiledWorkload>(*precompiledQueueDescriptor, info);
        }
        default :
            return nullptr;
    }
}

} // namespace armnn