blob: 562a45eea3c5af0b8f253614f5deaed2c4071e86 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Matthew Benthamc1c5f2a2023-03-30 14:24:46 +00006#include <backendsCommon/TensorHandleFactoryRegistry.hpp>
Narumol Prangnawarat77400452022-01-13 17:43:41 +00007#include <neon/NeonBackend.hpp>
8#include <neon/NeonTensorHandleFactory.hpp>
9
10#include <doctest/doctest.h>
11
12using namespace armnn;
13
14TEST_SUITE("NeonBackendTests")
15{
16TEST_CASE("NeonRegisterTensorHandleFactoriesMatchingImportFactoryId")
17{
18 auto neonBackend = std::make_unique<NeonBackend>();
19 TensorHandleFactoryRegistry registry;
20 neonBackend->RegisterTensorHandleFactories(registry);
21
22 // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
23 // Get matching import factory id correctly
24 CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
25 NeonTensorHandleFactory::GetIdStatic()));
26}
27
28TEST_CASE("NeonCreateWorkloadFactoryMatchingImportFactoryId")
29{
30 auto neonBackend = std::make_unique<NeonBackend>();
31 TensorHandleFactoryRegistry registry;
32 neonBackend->CreateWorkloadFactory(registry);
33
34 // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
35 // Get matching import factory id correctly
36 CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
37 NeonTensorHandleFactory::GetIdStatic()));
38}
39
40TEST_CASE("NeonCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
41{
42 auto neonBackend = std::make_unique<NeonBackend>();
43 TensorHandleFactoryRegistry registry;
44 ModelOptions modelOptions;
45 neonBackend->CreateWorkloadFactory(registry, modelOptions);
46
47 // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
48 // Get matching import factory id correctly
49 CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
50 NeonTensorHandleFactory::GetIdStatic()));
51}
52}