blob: 71324dbf818189a05c709fcfff065a2b158064e3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Colm Donelanc42a9872022-02-02 16:35:09 +00007#include "armnnTestUtils/TensorHelpers.hpp"
Sadik Armagana097d2a2021-11-24 15:47:28 +00008
9#include <armnn/Logging.hpp>
10#include <armnn/Utils.hpp>
11#include <reference/RefWorkloadFactory.hpp>
12#include <reference/test/RefWorkloadFactoryHelper.hpp>
13
14#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
15
16#include <armnnTestUtils/LayerTestResult.hpp>
17#include <armnnTestUtils/TensorCopyUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000018#include <armnnTestUtils/WorkloadTestUtils.hpp>
Sadik Armagana097d2a2021-11-24 15:47:28 +000019
20#include <doctest/doctest.h>
21
22inline void ConfigureLoggingTest()
23{
24 // Configures logging for both the ARMNN library and this test program.
25 armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
26}
27
// Every test-case macro defined further down in this file except ARMNN_SIMPLE_TEST_CASE refers to an
// unqualified FactoryType, so the including test file must define it beforehand with one of the
// following type aliases:
//
// using FactoryType = armnn::RefWorkloadFactory;
// using FactoryType = armnn::ClWorkloadFactory;
// using FactoryType = armnn::NeonWorkloadFactory;
33
34/// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported.
35/// If the test reports itself as not supported then the tensors are not compared.
36/// Additionally this checks that the supportedness reported by the test matches the name of the test.
37/// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
38/// This is useful because it clarifies that the feature being tested is not actually supported
39/// (a passed test with the name of a feature would imply that feature was supported).
40/// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
41/// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
42template <typename T, std::size_t n>
43void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
44{
45 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
46 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
47 "The test name does not match the supportedness it is reporting");
48 if (testResult.m_Supported)
49 {
50 auto result = CompareTensors(testResult.m_ActualData,
51 testResult.m_ExpectedData,
52 testResult.m_ActualShape,
53 testResult.m_ExpectedShape,
54 testResult.m_CompareBoolean);
55 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
56 }
57}
58
59template <typename T, std::size_t n>
60void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
61{
62 bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
63 for (unsigned int i = 0; i < testResult.size(); ++i)
64 {
65 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
66 "The test name does not match the supportedness it is reporting");
67 if (testResult[i].m_Supported)
68 {
69 auto result = CompareTensors(testResult[i].m_ActualData,
70 testResult[i].m_ExpectedData,
71 testResult[i].m_ActualShape,
72 testResult[i].m_ExpectedShape);
73 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
74 }
75 }
76}
77
78template<typename FactoryType, typename TFuncPtr, typename... Args>
79void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
80{
81 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
82 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
83
84 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
85 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
86
87 auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
88 CompareTestResultIfSupported(testName, testResult);
89
90 armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
91}
92
93
94template<typename FactoryType, typename TFuncPtr, typename... Args>
95void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
96{
97 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
98 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
99
100 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
101 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
102
103 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
104
105 auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
106 CompareTestResultIfSupported(testName, testResult);
107
108 armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
109}
110
/// Declares a doctest case named TestName whose body simply calls TestFunction().
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
    TEST_CASE(#TestName) { TestFunction(); }

/// Declares a doctest case that runs TestFunction through RunTestFunction<FactoryType>.
/// Any extra macro arguments are passed to TestFunction after the workload factory and memory manager.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) { RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Same as ARMNN_AUTO_TEST_CASE, but the doctest case runs with the given Fixture.
#define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) { RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Declares a doctest case that runs TestFunction through
/// RunTestFunctionUsingTensorHandleFactory<FactoryType>, so TestFunction also receives a tensor
/// handle factory.
#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Same as ARMNN_AUTO_TEST_CASE_WITH_THF, but the doctest case runs with the given Fixture.
#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }
140
141template<typename FactoryType, typename TFuncPtr, typename... Args>
142void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
143{
144 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
145 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
146
147 armnn::RefWorkloadFactory refWorkloadFactory;
148
149 auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
150 CompareTestResultIfSupported(testName, testResult);
151}
152
153template<typename FactoryType, typename TFuncPtr, typename... Args>
154void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
155{
156 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
Matthew Bentham79bb6532022-02-11 08:29:42 +0000158 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
Sadik Armagana097d2a2021-11-24 15:47:28 +0000159
160 armnn::RefWorkloadFactory refWorkloadFactory;
Matthew Bentham79bb6532022-02-11 08:29:42 +0000161 auto refMemoryManager = WorkloadFactoryHelper<armnn::RefWorkloadFactory>::GetMemoryManager();
162 auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager);
Sadik Armagana097d2a2021-11-24 15:47:28 +0000163
164 auto testResult = (*testFunction)(
165 workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
166 CompareTestResultIfSupported(testName, testResult);
167}
168
/// Declares a doctest case that runs TestFunction through CompareRefTestFunction<FactoryType>.
/// TestFunction therefore also receives a reference-backend workload factory; any extra macro
/// arguments are passed after the factory arguments.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) { CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Declares a doctest case that runs TestFunction through
/// CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>, so TestFunction also receives the
/// tensor handle factories of both backends.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
    TEST_CASE(#TestName) \
    { CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Same as ARMNN_COMPARE_REF_AUTO_TEST_CASE, but the doctest case runs with the given Fixture.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }

/// Same as ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF, but the doctest case runs with the given Fixture.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \
    TEST_CASE_FIXTURE(Fixture, #TestName) \
    { CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); }