Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 1 | // |
Teresa Charlin | fbf0e5b | 2020-08-17 01:01:06 +0100 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "GatherTestImpl.hpp" |
| 7 | |
| 8 | #include <ResolveType.hpp> |
| 9 | |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 10 | |
| 11 | #include <backendsCommon/test/TensorCopyUtils.hpp> |
| 12 | #include <backendsCommon/test/WorkloadTestUtils.hpp> |
| 13 | |
| 14 | #include <test/TensorHelpers.hpp> |
| 15 | |
| 16 | namespace |
| 17 | { |
| 18 | |
| 19 | template <armnn::DataType ArmnnType, |
| 20 | typename T = armnn::ResolveType<ArmnnType>, |
| 21 | size_t ParamsDim, |
| 22 | size_t IndicesDim, |
| 23 | size_t OutputDim> |
| 24 | LayerTestResult<T, OutputDim> GatherTestImpl( |
| 25 | armnn::IWorkloadFactory& workloadFactory, |
| 26 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 27 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 28 | const armnn::TensorInfo& paramsInfo, |
| 29 | const armnn::TensorInfo& indicesInfo, |
| 30 | const armnn::TensorInfo& outputInfo, |
| 31 | const std::vector<T>& paramsData, |
| 32 | const std::vector<int32_t>& indicesData, |
| 33 | const std::vector<T>& outputData) |
| 34 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 35 | IgnoreUnused(memoryManager); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 36 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 37 | std::vector<T> actualOutput(outputInfo.GetNumElements()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 38 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 39 | std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo); |
| 40 | std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo); |
| 41 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 42 | |
| 43 | armnn::GatherQueueDescriptor data; |
| 44 | armnn::WorkloadInfo info; |
| 45 | AddInputToWorkload(data, info, paramsInfo, paramsHandle.get()); |
| 46 | AddInputToWorkload(data, info, indicesInfo, indicesHandle.get()); |
| 47 | AddOutputToWorkload(data, info, outputInfo, outputHandle.get()); |
| 48 | |
| 49 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info); |
| 50 | |
| 51 | paramsHandle->Allocate(); |
| 52 | indicesHandle->Allocate(); |
| 53 | outputHandle->Allocate(); |
| 54 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 55 | CopyDataToITensorHandle(paramsHandle.get(), paramsData.data()); |
| 56 | CopyDataToITensorHandle(indicesHandle.get(), indicesData.data()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 57 | |
| 58 | workload->Execute(); |
| 59 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 60 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 61 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 62 | return LayerTestResult<T, OutputDim>(actualOutput, |
| 63 | outputData, |
| 64 | outputHandle->GetShape(), |
| 65 | outputInfo.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 66 | } |
| 67 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 68 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 69 | struct GatherTestHelper |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 70 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 71 | static LayerTestResult<T, 1> Gather1dParamsTestImpl( |
| 72 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 73 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 74 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 75 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 76 | armnn::TensorInfo paramsInfo({ 8 }, ArmnnType); |
| 77 | armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32); |
| 78 | armnn::TensorInfo outputInfo({ 4 }, ArmnnType); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 79 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 80 | if (armnn::IsQuantizedType<T>()) |
| 81 | { |
| 82 | paramsInfo.SetQuantizationScale(1.0f); |
| 83 | paramsInfo.SetQuantizationOffset(1); |
| 84 | outputInfo.SetQuantizationScale(1.0f); |
| 85 | outputInfo.SetQuantizationOffset(1); |
| 86 | } |
| 87 | const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 }); |
| 88 | const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 }); |
| 89 | const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 }); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 90 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 91 | return GatherTestImpl<ArmnnType, T, 1, 1, 1>( |
| 92 | workloadFactory, |
| 93 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 94 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 95 | paramsInfo, |
| 96 | indicesInfo, |
| 97 | outputInfo, |
| 98 | params, |
| 99 | indices, |
| 100 | expectedOutput); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 101 | } |
| 102 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 103 | static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl( |
| 104 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 105 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 106 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 107 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 108 | armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType); |
| 109 | armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); |
| 110 | armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType); |
| 111 | |
| 112 | if (armnn::IsQuantizedType<T>()) |
| 113 | { |
| 114 | paramsInfo.SetQuantizationScale(1.0f); |
| 115 | paramsInfo.SetQuantizationOffset(1); |
| 116 | outputInfo.SetQuantizationScale(1.0f); |
| 117 | outputInfo.SetQuantizationOffset(1); |
| 118 | } |
| 119 | |
| 120 | const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); |
| 121 | const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 }); |
| 122 | const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 }); |
| 123 | |
| 124 | return GatherTestImpl<ArmnnType, T, 2, 1, 2>( |
| 125 | workloadFactory, |
| 126 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 127 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 128 | paramsInfo, |
| 129 | indicesInfo, |
| 130 | outputInfo, |
| 131 | params, |
| 132 | indices, |
| 133 | expectedOutput); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 134 | } |
| 135 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 136 | static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl( |
| 137 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 138 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 139 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 140 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 141 | armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType); |
| 142 | armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); |
| 143 | armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 144 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 145 | if (armnn::IsQuantizedType<T>()) |
| 146 | { |
| 147 | paramsInfo.SetQuantizationScale(1.0f); |
| 148 | paramsInfo.SetQuantizationOffset(1); |
| 149 | outputInfo.SetQuantizationScale(1.0f); |
| 150 | outputInfo.SetQuantizationOffset(1); |
| 151 | } |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 152 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 153 | const std::vector<T> params = |
| 154 | { |
| 155 | 1, 2, 3, |
| 156 | 4, 5, 6, |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 157 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 158 | 7, 8, 9, |
| 159 | 10, 11, 12, |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 160 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 161 | 13, 14, 15, |
| 162 | 16, 17, 18 |
| 163 | }; |
| 164 | |
| 165 | const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 }; |
| 166 | |
| 167 | const std::vector<T> expectedOutput = |
| 168 | { |
| 169 | 7, 8, 9, |
| 170 | 10, 11, 12, |
| 171 | 13, 14, 15, |
| 172 | 16, 17, 18, |
| 173 | 7, 8, 9, |
| 174 | 10, 11, 12, |
| 175 | |
| 176 | 13, 14, 15, |
| 177 | 16, 17, 18, |
| 178 | 7, 8, 9, |
| 179 | 10, 11, 12, |
| 180 | 1, 2, 3, |
| 181 | 4, 5, 6 |
| 182 | }; |
| 183 | |
| 184 | return GatherTestImpl<ArmnnType, T, 3, 2, 4>( |
| 185 | workloadFactory, |
| 186 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 187 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 188 | paramsInfo, |
| 189 | indicesInfo, |
| 190 | outputInfo, |
| 191 | params, |
| 192 | indices, |
| 193 | expectedOutput); |
| 194 | } |
| 195 | }; |
| 196 | |
| 197 | template<typename T> |
| 198 | struct GatherTestHelper<armnn::DataType::Float16, T> |
| 199 | { |
| 200 | static LayerTestResult<T, 1> Gather1dParamsTestImpl( |
| 201 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 202 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 203 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 204 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 205 | using namespace half_float::literal; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 206 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 207 | armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16); |
| 208 | armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32); |
| 209 | armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 210 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 211 | const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h }); |
| 212 | const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 }); |
| 213 | const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h }); |
| 214 | |
| 215 | return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>( |
| 216 | workloadFactory, |
| 217 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 218 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 219 | paramsInfo, |
| 220 | indicesInfo, |
| 221 | outputInfo, |
| 222 | params, |
| 223 | indices, |
| 224 | expectedOutput); |
| 225 | } |
| 226 | |
| 227 | static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl( |
| 228 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 229 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 230 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 231 | { |
| 232 | using namespace half_float::literal; |
| 233 | |
| 234 | armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16); |
| 235 | armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); |
| 236 | armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16); |
| 237 | |
| 238 | const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h }); |
| 239 | |
| 240 | const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 }); |
| 241 | const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h }); |
| 242 | |
| 243 | return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>( |
| 244 | workloadFactory, |
| 245 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 246 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 247 | paramsInfo, |
| 248 | indicesInfo, |
| 249 | outputInfo, |
| 250 | params, |
| 251 | indices, |
| 252 | expectedOutput); |
| 253 | } |
| 254 | |
| 255 | static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl( |
| 256 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 257 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 258 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 259 | { |
| 260 | using namespace half_float::literal; |
| 261 | |
| 262 | armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16); |
| 263 | armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); |
| 264 | armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16); |
| 265 | |
| 266 | const std::vector<T> params = |
| 267 | { |
| 268 | 1._h, 2._h, 3._h, |
| 269 | 4._h, 5._h, 6._h, |
| 270 | |
| 271 | 7._h, 8._h, 9._h, |
| 272 | 10._h, 11._h, 12._h, |
| 273 | |
| 274 | 13._h, 14._h, 15._h, |
| 275 | 16._h, 17._h, 18._h |
| 276 | }; |
| 277 | |
| 278 | const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 }; |
| 279 | |
| 280 | const std::vector<T> expectedOutput = |
| 281 | { |
| 282 | 7._h, 8._h, 9._h, |
| 283 | 10._h, 11._h, 12._h, |
| 284 | 13._h, 14._h, 15._h, |
| 285 | 16._h, 17._h, 18._h, |
| 286 | 7._h, 8._h, 9._h, |
| 287 | 10._h, 11._h, 12._h, |
| 288 | |
| 289 | 13._h, 14._h, 15._h, |
| 290 | 16._h, 17._h, 18._h, |
| 291 | 7._h, 8._h, 9._h, |
| 292 | 10._h, 11._h, 12._h, |
| 293 | 1._h, 2._h, 3._h, |
| 294 | 4._h, 5._h, 6._h |
| 295 | }; |
| 296 | |
| 297 | return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>( |
| 298 | workloadFactory, |
| 299 | memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 300 | tensorHandleFactory, |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 301 | paramsInfo, |
| 302 | indicesInfo, |
| 303 | outputInfo, |
| 304 | params, |
| 305 | indices, |
| 306 | expectedOutput); |
| 307 | } |
| 308 | }; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 309 | |
| 310 | } // anonymous namespace |
| 311 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 312 | LayerTestResult<float, 1> Gather1dParamsFloat32Test( |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 313 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 314 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 315 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 316 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 317 | return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl( |
| 318 | workloadFactory, memoryManager, tensorHandleFactory); |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 319 | } |
| 320 | |
| 321 | LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test( |
| 322 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 323 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 324 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 325 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 326 | return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl( |
| 327 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 328 | } |
| 329 | |
| 330 | LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test( |
| 331 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 332 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 333 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 334 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 335 | return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl( |
| 336 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 337 | } |
| 338 | |
| 339 | LayerTestResult<int16_t, 1> Gather1dParamsInt16Test( |
| 340 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 341 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 342 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 343 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 344 | return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl( |
| 345 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 346 | } |
| 347 | |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 348 | LayerTestResult<int32_t, 1> Gather1dParamsInt32Test( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 349 | armnn::IWorkloadFactory& workloadFactory, |
| 350 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 351 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 352 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 353 | return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl( |
| 354 | workloadFactory, memoryManager, tensorHandleFactory); |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 355 | } |
| 356 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 357 | LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test( |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 358 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 359 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 360 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 361 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 362 | return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl( |
| 363 | workloadFactory, memoryManager, tensorHandleFactory); |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 364 | } |
| 365 | |
| 366 | LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test( |
| 367 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 368 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 369 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 370 | { |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 371 | return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl( |
| 372 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 373 | } |
| 374 | |
| 375 | LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test( |
| 376 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 377 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 378 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 379 | { |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 380 | return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 381 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 382 | } |
| 383 | |
| 384 | LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 385 | armnn::IWorkloadFactory& workloadFactory, |
| 386 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 387 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 388 | { |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 389 | return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 390 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 391 | } |
| 392 | |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 393 | LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 394 | armnn::IWorkloadFactory& workloadFactory, |
| 395 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 396 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 397 | { |
| 398 | return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 399 | workloadFactory, memoryManager, tensorHandleFactory); |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 400 | } |
| 401 | |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 402 | LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test( |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 403 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 404 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 405 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 406 | { |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 407 | return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 408 | workloadFactory, memoryManager, tensorHandleFactory); |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 409 | } |
| 410 | |
| 411 | LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test( |
| 412 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 413 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 414 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Matthew Jackson | 9bff144 | 2019-09-12 09:08:23 +0100 | [diff] [blame] | 415 | { |
| 416 | return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 417 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 418 | } |
| 419 | |
| 420 | LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test( |
| 421 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 422 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 423 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 424 | { |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 425 | return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 426 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 427 | } |
| 428 | |
| 429 | LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 430 | armnn::IWorkloadFactory& workloadFactory, |
| 431 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 432 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 433 | { |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 434 | return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 435 | workloadFactory, memoryManager, tensorHandleFactory); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 436 | } |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 437 | |
| 438 | LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 439 | armnn::IWorkloadFactory& workloadFactory, |
| 440 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 441 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 442 | { |
| 443 | return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 444 | workloadFactory, memoryManager, tensorHandleFactory); |
Teresa Charlin | 9349246 | 2020-05-29 13:08:59 +0100 | [diff] [blame] | 445 | } |