blob: 1c97bef467e07c396257806ecc70d2b6d0870d10 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Matteo Martincighf02e6cd2019-05-17 12:15:30 +01008#include "CommonTestUtils.hpp"
9
narpra01db2b1602019-01-23 15:23:11 +000010#include <armnn/INetwork.hpp>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010011#include <ResolveType.hpp>
narpra01db2b1602019-01-23 15:23:11 +000012
13namespace{
14
15armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
16 const armnn::TensorInfo& indicesInfo,
17 const armnn::TensorInfo& outputInfo,
18 const std::vector<int32_t>& indicesData)
19{
20 armnn::INetworkPtr net(armnn::INetwork::Create());
21
22 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
23 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
24 armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather");
25 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
26 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
27 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
28 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
29
30 return net;
31}
32
33template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
34void GatherEndToEnd(const std::vector<BackendId>& backends)
35{
36 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
37 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
38 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
39
40 paramsInfo.SetQuantizationScale(1.0f);
41 paramsInfo.SetQuantizationOffset(0);
42 outputInfo.SetQuantizationScale(1.0f);
43 outputInfo.SetQuantizationOffset(0);
44
45 // Creates structures for input & output.
46 std::vector<T> paramsData{
47 1, 2, 3, 4, 5, 6, 7, 8
48 };
49
50 std::vector<int32_t> indicesData{
51 7, 6, 5
52 };
53
54 std::vector<T> expectedOutput{
55 8, 7, 6
56 };
57
58 // Builds up the structure of the network
59 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
60
61 BOOST_TEST_CHECKPOINT("create a network");
62
63 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
64 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
65
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +000066 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +000067}
68
69template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
71{
72 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
73 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
74 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
75
76 paramsInfo.SetQuantizationScale(1.0f);
77 paramsInfo.SetQuantizationOffset(0);
78 outputInfo.SetQuantizationScale(1.0f);
79 outputInfo.SetQuantizationOffset(0);
80
81 // Creates structures for input & output.
82 std::vector<T> paramsData{
83 1, 2, 3,
84 4, 5, 6,
85
86 7, 8, 9,
87 10, 11, 12,
88
89 13, 14, 15,
90 16, 17, 18
91 };
92
93 std::vector<int32_t> indicesData{
94 1, 2, 1,
95 2, 1, 0
96 };
97
98 std::vector<T> expectedOutput{
99 7, 8, 9,
100 10, 11, 12,
101 13, 14, 15,
102 16, 17, 18,
103 7, 8, 9,
104 10, 11, 12,
105
106 13, 14, 15,
107 16, 17, 18,
108 7, 8, 9,
109 10, 11, 12,
110 1, 2, 3,
111 4, 5, 6
112 };
113
114 // Builds up the structure of the network
115 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
116
117 BOOST_TEST_CHECKPOINT("create a network");
118
119 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
120 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
121
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000122 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +0000123}
124
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000125} // anonymous namespace