blob: cf4294780dd78c56b5d6f1c5927136e19c057d6d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Sadik Armagana097d2a2021-11-24 15:47:28 +00008#include <CommonTestUtils.hpp>
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01009
narpra01db2b1602019-01-23 15:23:11 +000010#include <armnn/INetwork.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
narpra01db2b1602019-01-23 15:23:11 +000012
Sadik Armagan1625efc2021-06-10 18:24:34 +010013#include <doctest/doctest.h>
14
narpra01db2b1602019-01-23 15:23:11 +000015namespace{
16
17armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
18 const armnn::TensorInfo& indicesInfo,
19 const armnn::TensorInfo& outputInfo,
20 const std::vector<int32_t>& indicesData)
21{
22 armnn::INetworkPtr net(armnn::INetwork::Create());
23
Teresa Charlin52664732020-06-29 16:27:03 +010024 armnn::GatherDescriptor descriptor;
narpra01db2b1602019-01-23 15:23:11 +000025 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
26 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
Teresa Charlin52664732020-06-29 16:27:03 +010027 armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather");
narpra01db2b1602019-01-23 15:23:11 +000028 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
29 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
30 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
31 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
32
33 return net;
34}
35
36template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
37void GatherEndToEnd(const std::vector<BackendId>& backends)
38{
39 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
40 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
41 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
42
43 paramsInfo.SetQuantizationScale(1.0f);
44 paramsInfo.SetQuantizationOffset(0);
Cathal Corbett5b8093c2021-10-22 11:12:07 +010045 paramsInfo.SetConstant(true);
46 indicesInfo.SetConstant(true);
narpra01db2b1602019-01-23 15:23:11 +000047 outputInfo.SetQuantizationScale(1.0f);
48 outputInfo.SetQuantizationOffset(0);
49
50 // Creates structures for input & output.
51 std::vector<T> paramsData{
52 1, 2, 3, 4, 5, 6, 7, 8
53 };
54
55 std::vector<int32_t> indicesData{
56 7, 6, 5
57 };
58
59 std::vector<T> expectedOutput{
60 8, 7, 6
61 };
62
63 // Builds up the structure of the network
64 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
65
Sadik Armagan1625efc2021-06-10 18:24:34 +010066 CHECK(net);
narpra01db2b1602019-01-23 15:23:11 +000067
68 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
69 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
70
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +000071 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +000072}
73
74template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
75void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
76{
77 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
78 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
79 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
80
81 paramsInfo.SetQuantizationScale(1.0f);
82 paramsInfo.SetQuantizationOffset(0);
Cathal Corbett5b8093c2021-10-22 11:12:07 +010083 paramsInfo.SetConstant(true);
84 indicesInfo.SetConstant(true);
narpra01db2b1602019-01-23 15:23:11 +000085 outputInfo.SetQuantizationScale(1.0f);
86 outputInfo.SetQuantizationOffset(0);
87
88 // Creates structures for input & output.
89 std::vector<T> paramsData{
90 1, 2, 3,
91 4, 5, 6,
92
93 7, 8, 9,
94 10, 11, 12,
95
96 13, 14, 15,
97 16, 17, 18
98 };
99
100 std::vector<int32_t> indicesData{
101 1, 2, 1,
102 2, 1, 0
103 };
104
105 std::vector<T> expectedOutput{
106 7, 8, 9,
107 10, 11, 12,
108 13, 14, 15,
109 16, 17, 18,
110 7, 8, 9,
111 10, 11, 12,
112
113 13, 14, 15,
114 16, 17, 18,
115 7, 8, 9,
116 10, 11, 12,
117 1, 2, 3,
118 4, 5, 6
119 };
120
121 // Builds up the structure of the network
122 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
123
narpra01db2b1602019-01-23 15:23:11 +0000124 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
125 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
126
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000127 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +0000128}
129
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000130} // anonymous namespace