blob: 9b97070766619dd1971ef06c228d03ce57643b12 [file] [log] [blame]
David Monahan8a570462023-11-22 13:24:25 +00001//
2// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include <aclCommon/BaseMemoryManager.hpp>
8
9#include <armnn/Optional.hpp>
10
11namespace armnn
12{
13
14// Dynamic Fusion workload factory.
15class GpuFsaWorkloadFactory : public IWorkloadFactory
16{
17public:
18 explicit GpuFsaWorkloadFactory(const std::shared_ptr<GpuFsaMemoryManager>& memoryManager);
19 GpuFsaWorkloadFactory();
20
21 ~GpuFsaWorkloadFactory() {}
22
23 const BackendId& GetBackendId() const override;
24
25 static bool IsLayerSupported(const Layer& layer,
26 Optional<DataType> dataType,
27 std::string& outReasonIfUnsupported);
28
29 bool SupportsSubTensors() const override { return false; }
30
31 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
32 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& /*parent*/,
33 TensorShape const& /*subTensorShape*/,
34 unsigned int const* /*subTensorOrigin*/) const override
35 {
36 return nullptr;
37 }
38
39 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
40 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
41 const bool IsMemoryManaged = true) const override;
42
43 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
44 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
45 DataLayout dataLayout,
46 const bool IsMemoryManaged = true) const override;
47
48 std::unique_ptr<IWorkload> CreateWorkload(LayerType type,
49 const QueueDescriptor& descriptor,
50 const WorkloadInfo& info) const override;
51
52private:
53 template <typename QueueDescriptorType>
54 std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const;
55
56 mutable std::shared_ptr<GpuFsaMemoryManager> m_MemoryManager;
57};
58
59} // namespace armnn