blob: ffe7c8bc97c30c9a247d8d525f1dcdf2b22c6617 [file] [log] [blame]
Colm Donelan17948b52022-02-01 23:37:04 +00001//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <armnnTestUtils/MockMemoryManager.hpp>

#include <memory>
#include <utility>
10
11namespace armnn
12{
13
/// Identifier string for the mock tensor handle factory.
///
/// Usable in constant expressions. The returned pointer refers to a string
/// literal, so it stays valid for the whole program and must not be freed.
constexpr const char* MockTensorHandleFactoryId()
{
    constexpr const char* factoryId = "Arm/Mock/TensorHandleFactory";
    return factoryId;
}
18
19class MockTensorHandleFactory : public ITensorHandleFactory
20{
21
22public:
23 explicit MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)
24 : m_MemoryManager(mgr)
25 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
26 , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
27 {}
28
29 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
30 TensorShape const& subTensorShape,
31 unsigned int const* subTensorOrigin) const override;
32
33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
34
35 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
36 DataLayout dataLayout) const override;
37
38 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
39 const bool IsMemoryManaged) const override;
40
41 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
42 DataLayout dataLayout,
43 const bool IsMemoryManaged) const override;
44
45 static const FactoryId& GetIdStatic();
46
47 const FactoryId& GetId() const override;
48
49 bool SupportsSubTensors() const override;
50
51 MemorySourceFlags GetExportFlags() const override;
52
53 MemorySourceFlags GetImportFlags() const override;
54
55private:
56 mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
57 MemorySourceFlags m_ImportFlags;
58 MemorySourceFlags m_ExportFlags;
59};
60
61} // namespace armnn