blob: 6adaa5bd70b86bcd36d0bdb40fb842eea175d834 [file] [log] [blame]
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001//
Mike Kellya9c32672023-12-04 17:23:09 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <CommonTestUtils.hpp>
9
10#include <armnn/INetwork.hpp>
11#include <ResolveType.hpp>
12
13#include <doctest/doctest.h>
14
15namespace{
16
17armnn::INetworkPtr CreateGatherNdNetwork(const armnn::TensorInfo& paramsInfo,
18 const armnn::TensorInfo& indicesInfo,
19 const armnn::TensorInfo& outputInfo,
20 const std::vector<int32_t>& indicesData)
21{
22 armnn::INetworkPtr net(armnn::INetwork::Create());
23
24 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
25 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
26 armnn::IConnectableLayer* gatherNdLayer = net->AddGatherNdLayer("gatherNd");
27 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
28 Connect(paramsLayer, gatherNdLayer, paramsInfo, 0, 0);
29 Connect(indicesLayer, gatherNdLayer, indicesInfo, 0, 1);
30 Connect(gatherNdLayer, outputLayer, outputInfo, 0, 0);
31
32 return net;
33}
34
35template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
36void GatherNdEndToEnd(const std::vector<BackendId>& backends)
37{
38 armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
39 armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
40 armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
41
42 paramsInfo.SetQuantizationScale(1.0f);
43 paramsInfo.SetQuantizationOffset(0);
44 paramsInfo.SetConstant(true);
45 indicesInfo.SetConstant(true);
46 outputInfo.SetQuantizationScale(1.0f);
47 outputInfo.SetQuantizationOffset(0);
48
49 // Creates structures for input & output.
50 std::vector<T> paramsData{
51 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
52 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
53
54 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
55 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
56
57 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
58 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
59
60 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
61 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
62
63 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
64 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
65
66 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
67 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191
68 };
69
70 std::vector<int32_t> indicesData{
71 { 1, 2, 1, 1},
72 };
73
74 std::vector<T> expectedOutput{
75 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
76 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
77
78 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
79 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159
80 };
81
82 // Builds up the structure of the network
83 armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
84
85 CHECK(net);
86
87 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
88 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
89
Mike Kellya9c32672023-12-04 17:23:09 +000090 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +010091}
92
93template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
94void GatherNdMultiDimEndToEnd(const std::vector<BackendId>& backends)
95{
96 armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
97 armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
98 armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
99
100 paramsInfo.SetQuantizationScale(1.0f);
101 paramsInfo.SetQuantizationOffset(0);
102 paramsInfo.SetConstant(true);
103 indicesInfo.SetConstant(true);
104 outputInfo.SetQuantizationScale(1.0f);
105 outputInfo.SetQuantizationOffset(0);
106
107 // Creates structures for input & output.
108 std::vector<T> paramsData{
109 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
110 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
111 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
112 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
113 40, 41, 42, 43, 44, 45, 46, 47, 48, 49
114 };
115
116 std::vector<int32_t> indicesData{
117 0, 0,
118 3, 3,
119 4, 4,
120
121 0, 0,
122 1, 1,
123 2, 2,
124
125 4, 4,
126 3, 3,
127 0, 0,
128
129 2, 2,
130 1, 1,
131 0, 0
132 };
133
134 std::vector<T> expectedOutput{
135 0, 1,
136 36, 37,
137 48, 49,
138
139 0, 1,
140 12, 13,
141 24, 25,
142
143 48, 49,
144 36, 37,
145 0, 1,
146
147 24, 25,
148 12, 13,
149 0, 1
150 };
151
152 // Builds up the structure of the network
153 armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
154
155 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
156 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
157
Mike Kellya9c32672023-12-04 17:23:09 +0000158 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(net), inputTensorData, expectedOutputData, backends);
Teresa Charlinb2d3ec52022-04-12 22:07:09 +0100159}
160
161} // anonymous namespace