blob: 0eea91190e93689868d0c560cb520fbdac126c75 [file] [log] [blame]
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <CommonTestUtils.hpp>
9
10#include <armnn/INetwork.hpp>
11#include <ResolveType.hpp>
12
13#include <doctest/doctest.h>
14
15namespace{
16
17armnn::INetworkPtr CreateGatherNdNetwork(const armnn::TensorInfo& paramsInfo,
18 const armnn::TensorInfo& indicesInfo,
19 const armnn::TensorInfo& outputInfo,
20 const std::vector<int32_t>& indicesData)
21{
22 armnn::INetworkPtr net(armnn::INetwork::Create());
23
24 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
25 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
26 armnn::IConnectableLayer* gatherNdLayer = net->AddGatherNdLayer("gatherNd");
27 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
28 Connect(paramsLayer, gatherNdLayer, paramsInfo, 0, 0);
29 Connect(indicesLayer, gatherNdLayer, indicesInfo, 0, 1);
30 Connect(gatherNdLayer, outputLayer, outputInfo, 0, 0);
31
32 return net;
33}
34
35template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
36void GatherNdEndToEnd(const std::vector<BackendId>& backends)
37{
38 armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
39 armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
40 armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
41
42 paramsInfo.SetQuantizationScale(1.0f);
43 paramsInfo.SetQuantizationOffset(0);
44 paramsInfo.SetConstant(true);
45 indicesInfo.SetConstant(true);
46 outputInfo.SetQuantizationScale(1.0f);
47 outputInfo.SetQuantizationOffset(0);
48
49 // Creates structures for input & output.
50 std::vector<T> paramsData{
51 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
52 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
53
54 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
55 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
56
57 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
58 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
59
60 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
61 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
62
63 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
64 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
65
66 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
67 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191
68 };
69
70 std::vector<int32_t> indicesData{
71 { 1, 2, 1, 1},
72 };
73
74 std::vector<T> expectedOutput{
75 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
76 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
77
78 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
79 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159
80 };
81
82 // Builds up the structure of the network
83 armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
84
85 CHECK(net);
86
87 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
88 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
89
90 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
91}
92
93template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
94void GatherNdMultiDimEndToEnd(const std::vector<BackendId>& backends)
95{
96 armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
97 armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
98 armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
99
100 paramsInfo.SetQuantizationScale(1.0f);
101 paramsInfo.SetQuantizationOffset(0);
102 paramsInfo.SetConstant(true);
103 indicesInfo.SetConstant(true);
104 outputInfo.SetQuantizationScale(1.0f);
105 outputInfo.SetQuantizationOffset(0);
106
107 // Creates structures for input & output.
108 std::vector<T> paramsData{
109 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
110 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
111 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
112 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
113 40, 41, 42, 43, 44, 45, 46, 47, 48, 49
114 };
115
116 std::vector<int32_t> indicesData{
117 0, 0,
118 3, 3,
119 4, 4,
120
121 0, 0,
122 1, 1,
123 2, 2,
124
125 4, 4,
126 3, 3,
127 0, 0,
128
129 2, 2,
130 1, 1,
131 0, 0
132 };
133
134 std::vector<T> expectedOutput{
135 0, 1,
136 36, 37,
137 48, 49,
138
139 0, 1,
140 12, 13,
141 24, 25,
142
143 48, 49,
144 36, 37,
145 0, 1,
146
147 24, 25,
148 12, 13,
149 0, 1
150 };
151
152 // Builds up the structure of the network
153 armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
154
155 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
156 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
157
158 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
159}
160
161} // anonymous namespace