blob: 1a5c1f2ef5c17a3f2945cb86fe97c22c235c8b0c [file] [log] [blame]
Francis Murtaghbf354142022-08-12 13:54:17 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "TosaRefMemoryManager.hpp"
8
9#include <armnn/Optional.hpp>
10#include <armnn/backends/WorkloadFactory.hpp>
11#include <armnn/utility/IgnoreUnused.hpp>
12
13
14namespace armnn
15{
16
17// Reference workload factory.
18class TosaRefWorkloadFactory : public IWorkloadFactory
19{
20public:
21 explicit TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager);
22 TosaRefWorkloadFactory();
23
24 ~TosaRefWorkloadFactory() {}
25
26 const BackendId& GetBackendId() const override;
27
28 static bool IsLayerSupported(const Layer& layer,
29 Optional<DataType> dataType,
30 std::string& outReasonIfUnsupported);
31
32 static bool IsLayerSupported(const IConnectableLayer& layer,
33 Optional<DataType> dataType,
34 std::string& outReasonIfUnsupported,
35 const ModelOptions& modelOptions);
36
37 bool SupportsSubTensors() const override { return false; }
38
39 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
40 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
41 TensorShape const& subTensorShape,
42 unsigned int const* subTensorOrigin) const override
43 {
44 IgnoreUnused(parent, subTensorShape, subTensorOrigin);
45 return nullptr;
46 }
47
48 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
49 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50 const bool IsMemoryManaged = true) const override;
51
52 ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
53 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
54 DataLayout dataLayout,
55 const bool IsMemoryManaged = true) const override;
56
57 std::unique_ptr<IWorkload> CreateWorkload(LayerType type,
58 const QueueDescriptor& descriptor,
59 const WorkloadInfo& info) const override;
60
61private:
62 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
63 std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const;
64
65 mutable std::shared_ptr<TosaRefMemoryManager> m_MemoryManager;
66};
67
68} // namespace armnn