blob: 8fbfeeae3da2c6053f7917f47d52faff41221a17 [file] [log] [blame]
narpra014951d842019-01-18 16:53:53 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "WorkloadTestUtils.hpp"
8
9#include <armnn/Types.hpp>
10#include <backendsCommon/CpuTensorHandle.hpp>
11#include <backendsCommon/IBackendInternal.hpp>
12#include <backendsCommon/WorkloadFactory.hpp>
13
/// Builds and executes a Gather workload on the given backend and returns a
/// LayerTestResult holding both the actual output (read back from the backend
/// tensor handle) and the caller-supplied expected output, so the caller can
/// compare them.
///
/// @tparam ArmnnType   armnn::DataType of the params/output tensors.
/// @tparam T           C++ element type resolved from ArmnnType.
/// @tparam paramsDim   rank of the params tensor.
/// @tparam indicesDim  rank of the indices tensor.
/// @tparam OutputDim   rank of the output tensor.
///
/// @param workloadFactory factory used to create tensor handles and the Gather workload.
/// @param memoryManager   backend memory manager — note: not referenced in this
///                        body; presumably kept so the backend memory stays alive
///                        for the test's duration (signature shared by callers).
/// @param paramsInfo      shape/type of the values tensor to gather from.
/// @param indicesInfo     shape/type (Signed32) of the indices tensor.
/// @param outputInfo      shape/type of the gathered output tensor.
/// @param paramsData      flattened params values.
/// @param indicesData     flattened indices values.
/// @param outputData      flattened expected output values.
template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>,
          unsigned int paramsDim, unsigned int indicesDim, unsigned int OutputDim>
LayerTestResult<T, OutputDim> GatherTestImpl(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
    const armnn::TensorInfo& paramsInfo,
    const armnn::TensorInfo& indicesInfo,
    const armnn::TensorInfo& outputInfo,
    const std::vector<T>& paramsData,
    const std::vector<int32_t>& indicesData,
    const std::vector<T>& outputData)
{
    // Wrap the flat input data in multi-dimensional boost tensors.
    auto params = MakeTensor<T, paramsDim>(paramsInfo, paramsData);
    auto indices = MakeTensor<int32_t, indicesDim>(indicesInfo, indicesData);

    // The result carries both the to-be-filled actual output and the expected reference.
    LayerTestResult<T, OutputDim> result(outputInfo);
    result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);

    std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
    std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);

    // Describe the workload: inputs are (params, indices) in that order, one output.
    armnn::GatherQueueDescriptor data;
    armnn::WorkloadInfo info;
    AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
    AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
    AddOutputToWorkload(data, info, outputInfo, outputHandle.get());

    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);

    // Handles must be allocated before any data is copied into them.
    paramsHandle->Allocate();
    indicesHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(paramsHandle.get(), params.origin());
    CopyDataToITensorHandle(indicesHandle.get(), indices.origin());

    workload->Execute();

    // Read the computed output back only after Execute() has completed.
    CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());

    return result;
}
57
58template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
59LayerTestResult<T, 1> Gather1DParamsTestImpl(armnn::IWorkloadFactory& workloadFactory,
60 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
61{
62 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
63 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
64 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
65
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010066 if (armnn::IsQuantizedType<T>())
67 {
68 paramsInfo.SetQuantizationScale(1.0f);
69 paramsInfo.SetQuantizationOffset(1);
70 outputInfo.SetQuantizationScale(1.0f);
71 outputInfo.SetQuantizationOffset(1);
72 }
narpra014951d842019-01-18 16:53:53 +000073 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
74 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
75 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
76
77 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(workloadFactory, memoryManager,
78 paramsInfo, indicesInfo, outputInfo,
79 params,indices, expectedOutput);
80}
81
82template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
83LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
84 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
85{
86 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
87 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
88 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
89
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +010090 if (armnn::IsQuantizedType<T>())
91 {
92 paramsInfo.SetQuantizationScale(1.0f);
93 paramsInfo.SetQuantizationOffset(1);
94 outputInfo.SetQuantizationScale(1.0f);
95 outputInfo.SetQuantizationOffset(1);
96 }
97
narpra014951d842019-01-18 16:53:53 +000098 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
99 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
100 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
101
102 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(workloadFactory, memoryManager,
103 paramsInfo, indicesInfo, outputInfo,
104 params,indices, expectedOutput);
105}
106
107template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
108LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
109 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
110{
111 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
112 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
113 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
114
Ellen Norris-Thompson6858d3f2019-06-21 15:50:00 +0100115 if (armnn::IsQuantizedType<T>())
116 {
117 paramsInfo.SetQuantizationScale(1.0f);
118 paramsInfo.SetQuantizationOffset(1);
119 outputInfo.SetQuantizationScale(1.0f);
120 outputInfo.SetQuantizationOffset(1);
121 }
122
narpra014951d842019-01-18 16:53:53 +0000123 const std::vector<T> params = std::vector<T>({
124 1, 2, 3,
125 4, 5, 6,
126
127 7, 8, 9,
128 10, 11, 12,
129
130 13, 14, 15,
131 16, 17, 18 });
132 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 2, 1, 2, 1, 0 });
133 const std::vector<T> expectedOutput = std::vector<T>({
134 7, 8, 9,
135 10, 11, 12,
136 13, 14, 15,
137 16, 17, 18,
138 7, 8, 9,
139 10, 11, 12,
140
141 13, 14, 15,
142 16, 17, 18,
143 7, 8, 9,
144 10, 11, 12,
145 1, 2, 3,
146 4, 5, 6 });
147
148 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(workloadFactory, memoryManager,
149 paramsInfo, indicesInfo, outputInfo,
150 params,indices, expectedOutput);
151}