blob: 82f94512c3591c2f58a6bcf91e03cd86e2cdd61b [file] [log] [blame]
narpra01db2b1602019-01-23 15:23:11 +00001//
Teresa Charlin52664732020-06-29 16:27:03 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
narpra01db2b1602019-01-23 15:23:11 +00003// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "CommonTestUtils.hpp"

#include <armnn/INetwork.hpp>
#include <ResolveType.hpp>

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

13namespace{
14
15armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
16 const armnn::TensorInfo& indicesInfo,
17 const armnn::TensorInfo& outputInfo,
18 const std::vector<int32_t>& indicesData)
19{
20 armnn::INetworkPtr net(armnn::INetwork::Create());
21
Teresa Charlin52664732020-06-29 16:27:03 +010022 armnn::GatherDescriptor descriptor;
narpra01db2b1602019-01-23 15:23:11 +000023 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
24 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
Teresa Charlin52664732020-06-29 16:27:03 +010025 armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather");
narpra01db2b1602019-01-23 15:23:11 +000026 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
27 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
28 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
29 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
30
31 return net;
32}
33
34template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
35void GatherEndToEnd(const std::vector<BackendId>& backends)
36{
37 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
38 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
39 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
40
41 paramsInfo.SetQuantizationScale(1.0f);
42 paramsInfo.SetQuantizationOffset(0);
43 outputInfo.SetQuantizationScale(1.0f);
44 outputInfo.SetQuantizationOffset(0);
45
46 // Creates structures for input & output.
47 std::vector<T> paramsData{
48 1, 2, 3, 4, 5, 6, 7, 8
49 };
50
51 std::vector<int32_t> indicesData{
52 7, 6, 5
53 };
54
55 std::vector<T> expectedOutput{
56 8, 7, 6
57 };
58
59 // Builds up the structure of the network
60 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
61
62 BOOST_TEST_CHECKPOINT("create a network");
63
64 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
65 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
66
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +000067 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +000068}
69
70template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
71void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
72{
73 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
74 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
75 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
76
77 paramsInfo.SetQuantizationScale(1.0f);
78 paramsInfo.SetQuantizationOffset(0);
79 outputInfo.SetQuantizationScale(1.0f);
80 outputInfo.SetQuantizationOffset(0);
81
82 // Creates structures for input & output.
83 std::vector<T> paramsData{
84 1, 2, 3,
85 4, 5, 6,
86
87 7, 8, 9,
88 10, 11, 12,
89
90 13, 14, 15,
91 16, 17, 18
92 };
93
94 std::vector<int32_t> indicesData{
95 1, 2, 1,
96 2, 1, 0
97 };
98
99 std::vector<T> expectedOutput{
100 7, 8, 9,
101 10, 11, 12,
102 13, 14, 15,
103 16, 17, 18,
104 7, 8, 9,
105 10, 11, 12,
106
107 13, 14, 15,
108 16, 17, 18,
109 7, 8, 9,
110 10, 11, 12,
111 1, 2, 3,
112 4, 5, 6
113 };
114
115 // Builds up the structure of the network
116 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
117
118 BOOST_TEST_CHECKPOINT("create a network");
119
120 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
121 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
122
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000123 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +0000124}
125
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000126} // anonymous namespace