blob: d30da549df159b52d89cb0bd932c6bb4c23fd39b [file] [log] [blame]
narpra01db2b1602019-01-23 15:23:11 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/INetwork.hpp>
#include <backendsCommon/test/CommonTestUtils.hpp>
#include <TypeUtils.hpp>

#include <cstdint>
#include <map>
#include <utility>
#include <vector>
11
12namespace{
13
14armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
15 const armnn::TensorInfo& indicesInfo,
16 const armnn::TensorInfo& outputInfo,
17 const std::vector<int32_t>& indicesData)
18{
19 armnn::INetworkPtr net(armnn::INetwork::Create());
20
21 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
22 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
23 armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather");
24 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
25 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
26 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
27 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
28
29 return net;
30}
31
32template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
33void GatherEndToEnd(const std::vector<BackendId>& backends)
34{
35 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
36 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
37 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
38
39 paramsInfo.SetQuantizationScale(1.0f);
40 paramsInfo.SetQuantizationOffset(0);
41 outputInfo.SetQuantizationScale(1.0f);
42 outputInfo.SetQuantizationOffset(0);
43
44 // Creates structures for input & output.
45 std::vector<T> paramsData{
46 1, 2, 3, 4, 5, 6, 7, 8
47 };
48
49 std::vector<int32_t> indicesData{
50 7, 6, 5
51 };
52
53 std::vector<T> expectedOutput{
54 8, 7, 6
55 };
56
57 // Builds up the structure of the network
58 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
59
60 BOOST_TEST_CHECKPOINT("create a network");
61
62 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
63 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
64
65 EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
66}
67
68template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
69void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
70{
71 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
72 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
73 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
74
75 paramsInfo.SetQuantizationScale(1.0f);
76 paramsInfo.SetQuantizationOffset(0);
77 outputInfo.SetQuantizationScale(1.0f);
78 outputInfo.SetQuantizationOffset(0);
79
80 // Creates structures for input & output.
81 std::vector<T> paramsData{
82 1, 2, 3,
83 4, 5, 6,
84
85 7, 8, 9,
86 10, 11, 12,
87
88 13, 14, 15,
89 16, 17, 18
90 };
91
92 std::vector<int32_t> indicesData{
93 1, 2, 1,
94 2, 1, 0
95 };
96
97 std::vector<T> expectedOutput{
98 7, 8, 9,
99 10, 11, 12,
100 13, 14, 15,
101 16, 17, 18,
102 7, 8, 9,
103 10, 11, 12,
104
105 13, 14, 15,
106 16, 17, 18,
107 7, 8, 9,
108 10, 11, 12,
109 1, 2, 3,
110 4, 5, 6
111 };
112
113 // Builds up the structure of the network
114 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
115
116 BOOST_TEST_CHECKPOINT("create a network");
117
118 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
119 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
120
121 EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
122}
123
124} // anonymous namespace