blob: 40f97cd28eb0e71db39c2c34a5c1316512f2f578 [file] [log] [blame]
Colm Donelan17948b52022-02-01 23:37:04 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnnTestUtils/MockBackend.hpp>
#include <armnnTestUtils/MockTensorHandle.hpp>

#include <memory>
9
10namespace armnn
11{
12
/// Returns the identifier string of the mock backend ("CpuMock").
/// Usable in constant expressions.
constexpr const char* MockBackendId()
{
    constexpr const char* backendName = "CpuMock";
    return backendName;
}
17
18namespace
19{
20static const BackendId s_Id{ MockBackendId() };
21}
22
23MockWorkloadFactory::MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager)
24 : m_MemoryManager(memoryManager)
25{}
26
27MockWorkloadFactory::MockWorkloadFactory()
28 : m_MemoryManager(new MockMemoryManager())
29{}
30
31const BackendId& MockWorkloadFactory::GetBackendId() const
32{
33 return s_Id;
34}
35
36std::unique_ptr<IWorkload> MockWorkloadFactory::CreateWorkload(LayerType type,
37 const QueueDescriptor& descriptor,
38 const WorkloadInfo& info) const
39{
40 switch (type)
41 {
42 case LayerType::MemCopy: {
43 auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
44 if (descriptor.m_Inputs.empty())
45 {
46 throw InvalidArgumentException("MockWorkloadFactory: CreateMemCopy() expected an input tensor.");
47 }
48 return std::make_unique<CopyMemGenericWorkload>(*memCopyQueueDescriptor, info);
49 }
50 default:
51 return nullptr;
52 }
53}
54
55} // namespace armnn