blob: 9e2139667b8bdf3d857f90f0f2a8835deef0d382 [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <armnnTestUtils/LayerTestResult.hpp>
9
10#include <ResolveType.hpp>
11
12#include <armnn/backends/IBackendInternal.hpp>
13
14template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims>
15LayerTestResult<T, NumDims> BatchMatMulTestImpl(
16 armnn::IWorkloadFactory& workloadFactory,
17 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
18 const armnn::ITensorHandleFactory& tensorHandleFactory,
19 armnn::BatchMatMulDescriptor descriptor,
20 const std::vector<T>& inputX,
21 const std::vector<T>& inputY,
22 const std::vector<T>& outputExpected,
23 const armnn::TensorInfo& inputXInfo,
24 const armnn::TensorInfo& inputYInfo,
25 const armnn::TensorInfo& outputInfo);
26
27template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
28LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
29 armnn::IWorkloadFactory& workloadFactory,
30 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
31 const armnn::ITensorHandleFactory& tensorHandleFactory);
32
33template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
34LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
35 armnn::IWorkloadFactory& workloadFactory,
36 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
37 const armnn::ITensorHandleFactory& tensorHandleFactory);
38
39template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
40LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
41 armnn::IWorkloadFactory& workloadFactory,
42 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
43 const armnn::ITensorHandleFactory& tensorHandleFactory);
44
45template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
46LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
47 armnn::IWorkloadFactory& workloadFactory,
48 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
49 const armnn::ITensorHandleFactory& tensorHandleFactory);
50
51template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
52LayerTestResult<T, 3> BatchMatMul3DBatchTest(
53 armnn::IWorkloadFactory& workloadFactory,
54 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
55 const armnn::ITensorHandleFactory& tensorHandleFactory);
56
57template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
58LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
59 armnn::IWorkloadFactory& workloadFactory,
60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
61 const armnn::ITensorHandleFactory& tensorHandleFactory);
62
63template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
64LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
65 armnn::IWorkloadFactory& workloadFactory,
66 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
67 const armnn::ITensorHandleFactory& tensorHandleFactory);
68
69template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
71 armnn::IWorkloadFactory& workloadFactory,
72 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
73 const armnn::ITensorHandleFactory& tensorHandleFactory);
74
75template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
76LayerTestResult<T, 2> BatchMatMul2DTinyTest(
77 armnn::IWorkloadFactory& workloadFactory,
78 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
79 const armnn::ITensorHandleFactory& tensorHandleFactory);
80
81template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
82LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
83 armnn::IWorkloadFactory& workloadFactory,
84 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
85 const armnn::ITensorHandleFactory& tensorHandleFactory);