blob: f0341e24eeec0f7eab90a66c9b37a62352493dde [file] [log] [blame]
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ParserFlatbuffersSerializeFixture.hpp"
7#include <armnnDeserializer/IDeserializer.hpp>
8
9#include <string>
10
11TEST_SUITE("Deserializer_GatherNd")
12{
13struct GatherNdFixture : public ParserFlatbuffersSerializeFixture
14{
15 explicit GatherNdFixture(const std::string& paramsShape,
16 const std::string& indicesShape,
17 const std::string& outputShape,
18 const std::string& indicesData,
19 const std::string dataType,
20 const std::string constDataType)
21 {
22 m_JsonString = R"(
23 {
24 inputIds: [0],
25 outputIds: [3],
26 layers: [
27 {
28 layer_type: "InputLayer",
29 layer: {
30 base: {
31 layerBindingId: 0,
32 base: {
33 index: 0,
34 layerName: "InputLayer",
35 layerType: "Input",
36 inputSlots: [{
37 index: 0,
38 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
39 }],
40 outputSlots: [ {
41 index: 0,
42 tensorInfo: {
43 dimensions: )" + paramsShape + R"(,
44 dataType: )" + dataType + R"(
45 }}]
46 }
47 }}},
48 {
49 layer_type: "ConstantLayer",
50 layer: {
51 base: {
52 index:1,
53 layerName: "ConstantLayer",
54 layerType: "Constant",
55 outputSlots: [ {
56 index: 0,
57 tensorInfo: {
58 dimensions: )" + indicesShape + R"(,
59 dataType: "Signed32",
60 },
61 }],
62 },
63 input: {
64 info: {
65 dimensions: )" + indicesShape + R"(,
66 dataType: )" + dataType + R"(
67 },
68 data_type: )" + constDataType + R"(,
69 data: {
70 data: )" + indicesData + R"(,
71 } }
72 },},
73 {
74 layer_type: "GatherNdLayer",
75 layer: {
76 base: {
77 index: 2,
78 layerName: "GatherNdLayer",
79 layerType: "GatherNd",
80 inputSlots: [
81 {
82 index: 0,
83 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
84 },
85 {
86 index: 1,
87 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
88 }],
89 outputSlots: [ {
90 index: 0,
91 tensorInfo: {
92 dimensions: )" + outputShape + R"(,
93 dataType: )" + dataType + R"(
94
95 }}]},
96 }},
97 {
98 layer_type: "OutputLayer",
99 layer: {
100 base:{
101 layerBindingId: 0,
102 base: {
103 index: 3,
104 layerName: "OutputLayer",
105 layerType: "Output",
106 inputSlots: [{
107 index: 0,
108 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
109 }],
110 outputSlots: [ {
111 index: 0,
112 tensorInfo: {
113 dimensions: )" + outputShape + R"(,
114 dataType: )" + dataType + R"(
115 },
116 }],
117 }}},
Cathal Corbett06902652022-04-14 17:55:11 +0100118 }],
119 featureVersions: {
120 weightsLayoutScheme: 1,
121 }
Teresa Charlin6966bfa2022-04-25 17:14:50 +0100122 } )";
123
124 Setup();
125 }
126};
127
128struct SimpleGatherNdFixtureFloat32 : GatherNdFixture
129{
130 SimpleGatherNdFixtureFloat32() : GatherNdFixture("[ 6, 3 ]", "[ 3, 1 ]", "[ 3, 3 ]",
131 "[ 5, 1, 0 ]", "Float32", "IntData") {}
132};
133
134TEST_CASE_FIXTURE(SimpleGatherNdFixtureFloat32, "GatherNdFloat32")
135{
136 RunTest<4, armnn::DataType::Float32>(0,
137 {{"InputLayer", { 1, 2, 3,
138 4, 5, 6,
139 7, 8, 9,
140 10, 11, 12,
141 13, 14, 15,
142 16, 17, 18 }}},
143 {{"OutputLayer", { 16, 17, 18,
144 4, 5, 6,
145 1, 2, 3}}});
146}
147
148}
149