blob: 33f321653cc68d4961306a73dea44b1e2a4d1c11 [file] [log] [blame]
Narumol Prangnawaratd12b4072022-01-17 18:03:14 +00001//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <cl/ClBackend.hpp>
7#include <cl/ClTensorHandleFactory.hpp>
8#include <cl/ClImportTensorHandleFactory.hpp>
9#include <cl/test/ClContextControlFixture.hpp>
10
11#include <doctest/doctest.h>
12
13using namespace armnn;
14
15TEST_SUITE("ClBackendTests")
16{
17TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId")
18{
19 auto clBackend = std::make_unique<ClBackend>();
20 TensorHandleFactoryRegistry registry;
21 clBackend->RegisterTensorHandleFactories(registry);
22
23 // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
24 // Get ClImportTensorHandleFactory id as the matching import factory id
25 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
26 ClImportTensorHandleFactory::GetIdStatic()));
27}
28
29TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId")
30{
31 auto clBackend = std::make_unique<ClBackend>();
32 TensorHandleFactoryRegistry registry;
33 clBackend->RegisterTensorHandleFactories(registry,
34 static_cast<MemorySourceFlags>(MemorySource::Malloc),
35 static_cast<MemorySourceFlags>(MemorySource::Malloc));
36
37 // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered
38 // Get ClImportTensorHandleFactory id as the matching import factory id
39 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
40 ClImportTensorHandleFactory::GetIdStatic()));
41}
42
43TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId")
44{
45 auto clBackend = std::make_unique<ClBackend>();
46 TensorHandleFactoryRegistry registry;
47 clBackend->CreateWorkloadFactory(registry);
48
49 // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
50 // Get ClImportTensorHandleFactory id as the matching import factory id
51 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
52 ClImportTensorHandleFactory::GetIdStatic()));
53}
54
55TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
56{
57 auto clBackend = std::make_unique<ClBackend>();
58 TensorHandleFactoryRegistry registry;
59 ModelOptions modelOptions;
60 clBackend->CreateWorkloadFactory(registry, modelOptions);
61
62 // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
63 // Get ClImportTensorHandleFactory id as the matching import factory id
64 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
65 ClImportTensorHandleFactory::GetIdStatic()));
66}
67
68TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWitMemoryFlagsMatchingImportFactoryId")
69{
70 auto clBackend = std::make_unique<ClBackend>();
71 TensorHandleFactoryRegistry registry;
72 ModelOptions modelOptions;
73 clBackend->CreateWorkloadFactory(registry, modelOptions,
74 static_cast<MemorySourceFlags>(MemorySource::Malloc),
75 static_cast<MemorySourceFlags>(MemorySource::Malloc));
76
77 // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags,
78 // CopyAndImportFactoryPair is registered
79 // Get ClImportTensorHandleFactory id as the matching import factory id
80 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
81 ClImportTensorHandleFactory::GetIdStatic()));
82}
83}