blob: e9614a7c62bbd98130c91cfc612238c45d513321 [file] [log] [blame]
Narumol Prangnawaratd12b4072022-01-17 18:03:14 +00001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Matthew Benthamc1c5f2a2023-03-30 14:24:46 +00006#include <backendsCommon/TensorHandleFactoryRegistry.hpp>
Narumol Prangnawaratd12b4072022-01-17 18:03:14 +00007#include <cl/ClBackend.hpp>
8#include <cl/ClTensorHandleFactory.hpp>
9#include <cl/ClImportTensorHandleFactory.hpp>
10#include <cl/test/ClContextControlFixture.hpp>
11
12#include <doctest/doctest.h>
13
14using namespace armnn;
15
16TEST_SUITE("ClBackendTests")
17{
18TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId")
19{
20 auto clBackend = std::make_unique<ClBackend>();
21 TensorHandleFactoryRegistry registry;
22 clBackend->RegisterTensorHandleFactories(registry);
23
24 // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
25 // Get ClImportTensorHandleFactory id as the matching import factory id
26 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
27 ClImportTensorHandleFactory::GetIdStatic()));
28}
29
30TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId")
31{
32 auto clBackend = std::make_unique<ClBackend>();
33 TensorHandleFactoryRegistry registry;
34 clBackend->RegisterTensorHandleFactories(registry,
35 static_cast<MemorySourceFlags>(MemorySource::Malloc),
36 static_cast<MemorySourceFlags>(MemorySource::Malloc));
37
38 // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered
39 // Get ClImportTensorHandleFactory id as the matching import factory id
40 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
41 ClImportTensorHandleFactory::GetIdStatic()));
42}
43
44TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId")
45{
46 auto clBackend = std::make_unique<ClBackend>();
47 TensorHandleFactoryRegistry registry;
48 clBackend->CreateWorkloadFactory(registry);
49
50 // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
51 // Get ClImportTensorHandleFactory id as the matching import factory id
52 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
53 ClImportTensorHandleFactory::GetIdStatic()));
54}
55
56TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
57{
58 auto clBackend = std::make_unique<ClBackend>();
59 TensorHandleFactoryRegistry registry;
60 ModelOptions modelOptions;
61 clBackend->CreateWorkloadFactory(registry, modelOptions);
62
63 // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
64 // Get ClImportTensorHandleFactory id as the matching import factory id
65 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
66 ClImportTensorHandleFactory::GetIdStatic()));
67}
68
69TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWitMemoryFlagsMatchingImportFactoryId")
70{
71 auto clBackend = std::make_unique<ClBackend>();
72 TensorHandleFactoryRegistry registry;
73 ModelOptions modelOptions;
74 clBackend->CreateWorkloadFactory(registry, modelOptions,
75 static_cast<MemorySourceFlags>(MemorySource::Malloc),
76 static_cast<MemorySourceFlags>(MemorySource::Malloc));
77
78 // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags,
79 // CopyAndImportFactoryPair is registered
80 // Get ClImportTensorHandleFactory id as the matching import factory id
81 CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
82 ClImportTensorHandleFactory::GetIdStatic()));
83}
84}