//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "WorkloadTestUtils.hpp"

#include <armnn/Types.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

#include <cstdint>
#include <memory>
#include <vector>

14template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>,
15 unsigned int paramsDim, unsigned int indicesDim, unsigned int OutputDim>
16LayerTestResult<T, OutputDim> GatherTestImpl(
17 armnn::IWorkloadFactory& workloadFactory,
18 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
19 const armnn::TensorInfo& paramsInfo,
20 const armnn::TensorInfo& indicesInfo,
21 const armnn::TensorInfo& outputInfo,
22 const std::vector<T>& paramsData,
23 const std::vector<int32_t>& indicesData,
24 const std::vector<T>& outputData)
25{
26 auto params = MakeTensor<T, paramsDim>(paramsInfo, paramsData);
27 auto indices = MakeTensor<int32_t, indicesDim>(indicesInfo, indicesData);
28
29 LayerTestResult<T, OutputDim> result(outputInfo);
30 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
31
32 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
33 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
34 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);
35
36 armnn::GatherQueueDescriptor data;
37 armnn::WorkloadInfo info;
38 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
39 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
40 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
41
42 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
43
44 paramsHandle->Allocate();
45 indicesHandle->Allocate();
46 outputHandle->Allocate();
47
48 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
49 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
50
51 workload->Execute();
52
53 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
54
55 return result;
56}
57
58template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
59LayerTestResult<T, 1> Gather1DParamsTestImpl(armnn::IWorkloadFactory& workloadFactory,
60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
61{
62 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
63 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
64 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
65
66 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
67 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
68 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
69
70 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(workloadFactory, memoryManager,
71 paramsInfo, indicesInfo, outputInfo,
72 params,indices, expectedOutput);
73}
74
75template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
76LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
77 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
78{
79 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
80 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
81 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
82
83 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
84 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
85 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
86
87 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(workloadFactory, memoryManager,
88 paramsInfo, indicesInfo, outputInfo,
89 params,indices, expectedOutput);
90}
91
92template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
93LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
94 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
95{
96 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
97 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
98 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
99
100 const std::vector<T> params = std::vector<T>({
101 1, 2, 3,
102 4, 5, 6,
103
104 7, 8, 9,
105 10, 11, 12,
106
107 13, 14, 15,
108 16, 17, 18 });
109 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 2, 1, 2, 1, 0 });
110 const std::vector<T> expectedOutput = std::vector<T>({
111 7, 8, 9,
112 10, 11, 12,
113 13, 14, 15,
114 16, 17, 18,
115 7, 8, 9,
116 10, 11, 12,
117
118 13, 14, 15,
119 16, 17, 18,
120 7, 8, 9,
121 10, 11, 12,
122 1, 2, 3,
123 4, 5, 6 });
124
125 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(workloadFactory, memoryManager,
126 paramsInfo, indicesInfo, outputInfo,
127 params,indices, expectedOutput);
128}