blob: 0b261fba373b4de608f04d9b2ec481b63515ea3c [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnnTestUtils/LayerTestResult.hpp>
9
10#include <ResolveType.hpp>
11
12#include <armnn/backends/IBackendInternal.hpp>
13
14template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims>
15LayerTestResult<T, NumDims> BatchMatMulTestImpl(
16 armnn::IWorkloadFactory& workloadFactory,
17 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
18 const armnn::ITensorHandleFactory& tensorHandleFactory,
19 armnn::BatchMatMulDescriptor descriptor,
20 const std::vector<T>& inputX,
21 const std::vector<T>& inputY,
22 const std::vector<T>& outputExpected,
23 const armnn::TensorInfo& inputXInfo,
24 const armnn::TensorInfo& inputYInfo,
25 const armnn::TensorInfo& outputInfo);
26
27template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
28LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
29 armnn::IWorkloadFactory& workloadFactory,
30 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
31 const armnn::ITensorHandleFactory& tensorHandleFactory);
32
33template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
34LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
35 armnn::IWorkloadFactory& workloadFactory,
36 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
37 const armnn::ITensorHandleFactory& tensorHandleFactory);
38
39template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
40LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
41 armnn::IWorkloadFactory& workloadFactory,
42 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
43 const armnn::ITensorHandleFactory& tensorHandleFactory);
44
45template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
46LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
47 armnn::IWorkloadFactory& workloadFactory,
48 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
49 const armnn::ITensorHandleFactory& tensorHandleFactory);
50
51template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
52LayerTestResult<T, 3> BatchMatMul3DBatchTest(
53 armnn::IWorkloadFactory& workloadFactory,
54 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
55 const armnn::ITensorHandleFactory& tensorHandleFactory);
56
57template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
58LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
59 armnn::IWorkloadFactory& workloadFactory,
60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
61 const armnn::ITensorHandleFactory& tensorHandleFactory);
62
63template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
64LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
65 armnn::IWorkloadFactory& workloadFactory,
66 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
67 const armnn::ITensorHandleFactory& tensorHandleFactory);
68
69template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
71 armnn::IWorkloadFactory& workloadFactory,
72 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
73 const armnn::ITensorHandleFactory& tensorHandleFactory);
74
75template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
76LayerTestResult<T, 2> BatchMatMul2DTinyTest(
77 armnn::IWorkloadFactory& workloadFactory,
78 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
79 const armnn::ITensorHandleFactory& tensorHandleFactory);
80
81template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
82LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
83 armnn::IWorkloadFactory& workloadFactory,
84 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Samuel Yapdc8ed9d2022-08-08 14:07:42 +010085 const armnn::ITensorHandleFactory& tensorHandleFactory);
86
87template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
88LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
89 armnn::IWorkloadFactory& workloadFactory,
90 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
91 const armnn::ITensorHandleFactory& tensorHandleFactory);
92
93template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
94LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
95 armnn::IWorkloadFactory& workloadFactory,
96 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
97 const armnn::ITensorHandleFactory& tensorHandleFactory);
98
99template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
100LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
101 armnn::IWorkloadFactory& workloadFactory,
102 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Samuel Yap6b478092022-07-06 15:36:03 +0100103 const armnn::ITensorHandleFactory& tensorHandleFactory);