blob: 6d13879f517479939dc586787936ad92c33777ea [file] [log] [blame]
David Monahan8a570462023-11-22 13:24:25 +00001//
2// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <Layer.hpp>

#include <algorithm>
#include <memory>

#include "GpuFsaWorkloadFactory.hpp"
#include "GpuFsaBackendId.hpp"
#include "GpuFsaTensorHandle.hpp"
11
12namespace armnn
13{
14
15namespace
16{
17static const BackendId s_Id{GpuFsaBackendId()};
18}
/// Stub workload-creation helper for the GpuFsa backend.
/// Always returns nullptr: workloads are not built through this templated
/// path in this factory.
/// @tparam QueueDescriptorType Descriptor type of the requested workload (unused).
/// @return nullptr in all cases.
template <typename QueueDescriptorType>
std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::MakeWorkload(const QueueDescriptorType& /*descriptor*/,
                                                               const WorkloadInfo& /*info*/) const
{
    return nullptr;
}
25
26template <DataType ArmnnType>
27bool IsDataType(const WorkloadInfo& info)
28{
29 auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
30 auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
31 if (it != std::end(info.m_InputTensorInfos))
32 {
33 return true;
34 }
35 it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
36 if (it != std::end(info.m_OutputTensorInfos))
37 {
38 return true;
39 }
40 return false;
41}
42
43GpuFsaWorkloadFactory::GpuFsaWorkloadFactory(const std::shared_ptr<GpuFsaMemoryManager>& memoryManager)
44 : m_MemoryManager(memoryManager)
45{
46}
47
48GpuFsaWorkloadFactory::GpuFsaWorkloadFactory()
49 : m_MemoryManager(new GpuFsaMemoryManager())
50{
51}
52
/// Returns the identifier of the backend this factory serves (GpuFsa).
const BackendId& GpuFsaWorkloadFactory::GetBackendId() const
{
    return s_Id;
}
57
/// Queries whether \p layer is supported on the GpuFsa backend by delegating
/// to the shared IWorkloadFactory implementation with this backend's id.
/// @param layer                  Layer to query support for.
/// @param dataType               Optional data type to check the layer against.
/// @param outReasonIfUnsupported Filled with the reason when unsupported.
/// @return true if the layer is supported on this backend.
bool GpuFsaWorkloadFactory::IsLayerSupported(const Layer& layer,
                                             Optional<DataType> dataType,
                                             std::string& outReasonIfUnsupported)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}
64
65std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
66 const bool /*isMemoryManaged*/) const
67{
68 std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo);
69 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
70
71 return tensorHandle;
72}
73
74std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
75 DataLayout dataLayout,
76 const bool /*isMemoryManaged*/) const
77{
78 std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, dataLayout);
79 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
80
81 return tensorHandle;
82}
83
/// Stub implementation of the generic workload-creation entry point: this
/// factory returns nullptr for every layer type, descriptor, and info.
/// @return nullptr in all cases.
std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType /*type*/,
                                                                 const QueueDescriptor& /*descriptor*/,
                                                                 const WorkloadInfo& /*info*/) const
{
    return nullptr;
}
90
91} // namespace armnn