blob: 431ef31437afd4a3148953c02c30cfa5729fa805 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "CommonTestUtils.hpp"

#include <armnn/INetwork.hpp>
#include <ResolveType.hpp>

#include <doctest/doctest.h>

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

narpra01db2b1602019-01-23 15:23:11 +000015namespace{
16
17armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
18 const armnn::TensorInfo& indicesInfo,
19 const armnn::TensorInfo& outputInfo,
20 const std::vector<int32_t>& indicesData)
21{
22 armnn::INetworkPtr net(armnn::INetwork::Create());
23
Teresa Charlin52664732020-06-29 16:27:03 +010024 armnn::GatherDescriptor descriptor;
narpra01db2b1602019-01-23 15:23:11 +000025 armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
26 armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
Teresa Charlin52664732020-06-29 16:27:03 +010027 armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather");
narpra01db2b1602019-01-23 15:23:11 +000028 armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
29 Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
30 Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
31 Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
32
33 return net;
34}
35
36template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
37void GatherEndToEnd(const std::vector<BackendId>& backends)
38{
39 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
40 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
41 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
42
43 paramsInfo.SetQuantizationScale(1.0f);
44 paramsInfo.SetQuantizationOffset(0);
45 outputInfo.SetQuantizationScale(1.0f);
46 outputInfo.SetQuantizationOffset(0);
47
48 // Creates structures for input & output.
49 std::vector<T> paramsData{
50 1, 2, 3, 4, 5, 6, 7, 8
51 };
52
53 std::vector<int32_t> indicesData{
54 7, 6, 5
55 };
56
57 std::vector<T> expectedOutput{
58 8, 7, 6
59 };
60
61 // Builds up the structure of the network
62 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
63
Sadik Armagan1625efc2021-06-10 18:24:34 +010064 CHECK(net);
narpra01db2b1602019-01-23 15:23:11 +000065
66 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
67 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
68
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +000069 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +000070}
71
72template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
73void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
74{
75 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
76 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
77 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
78
79 paramsInfo.SetQuantizationScale(1.0f);
80 paramsInfo.SetQuantizationOffset(0);
81 outputInfo.SetQuantizationScale(1.0f);
82 outputInfo.SetQuantizationOffset(0);
83
84 // Creates structures for input & output.
85 std::vector<T> paramsData{
86 1, 2, 3,
87 4, 5, 6,
88
89 7, 8, 9,
90 10, 11, 12,
91
92 13, 14, 15,
93 16, 17, 18
94 };
95
96 std::vector<int32_t> indicesData{
97 1, 2, 1,
98 2, 1, 0
99 };
100
101 std::vector<T> expectedOutput{
102 7, 8, 9,
103 10, 11, 12,
104 13, 14, 15,
105 16, 17, 18,
106 7, 8, 9,
107 10, 11, 12,
108
109 13, 14, 15,
110 16, 17, 18,
111 7, 8, 9,
112 10, 11, 12,
113 1, 2, 3,
114 4, 5, 6
115 };
116
117 // Builds up the structure of the network
118 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
119
narpra01db2b1602019-01-23 15:23:11 +0000120 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
121 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
122
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000123 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
narpra01db2b1602019-01-23 15:23:11 +0000124}
125
Nattapat Chaimanowong1fcb4ff2019-01-24 15:25:26 +0000126} // anonymous namespace