//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#include "armnnTfParser/ITfParser.hpp"
7
8#include "ParserPrototxtFixture.hpp"
9#include <PrototxtConversions.hpp>
10
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
// Appends one prototxt `dim { size: N }` entry per element of `dims` to
// `text`, in order. Used to fill in the `shape { ... }` and
// `tensor_shape { ... }` sections of the test graph.
//
// @param dims  Dimension sizes to emit, in the order they should appear.
// @param text  Prototxt string being built; entries are appended in place.
void dimsHelper(const std::vector<int>& dims, std::string& text)
{
    // Range-for instead of an index loop over `u_int`: `u_int` is a
    // non-standard BSD typedef (from <sys/types.h>) that is not available on
    // every toolchain, and it is narrower than the std::size_t from size().
    for (const int dim : dims)
    {
        text.append(R"(dim {
                size: )");
        text.append(std::to_string(dim));
        text.append(R"(
              })");
    }
}

26// helper for converting from integer to octal representation
27void octalHelper(const std::vector<int>& indicesContent, std::string& text){
28 for (unsigned int i = 0; i < indicesContent.size(); ++i)
29 {
30 text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
31 }
32}
33
34struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
35{
36 GatherFixture(const armnn::TensorShape& inputShape0,
37 const armnn::TensorShape& inputShape1,
38 const std::vector<int>& input1Content,
39 const std::vector<int>& input0Dims,
40 const std::vector<int>& input1Dims)
41 {
42 m_Prototext = R"(
43node {
44 name: "input0"
45 op: "Placeholder"
46 attr {
47 key: "dtype"
48 value {
49 type: DT_FLOAT
50 }
51 }
52 attr {
53 key: "shape"
54 value {
55 shape {
56)";
57 dimsHelper(input0Dims, m_Prototext);
58 m_Prototext.append(R"(
59 }
60 }
61 }
62}
63node {
64 name: "input1"
65 op: "Const"
66 attr {
67 key: "dtype"
68 value {
69 type: DT_INT32
70 }
71 }
72 attr {
73 key: "value"
74 value {
75 tensor {
76 dtype: DT_INT32
77 tensor_shape {
78)");
79 dimsHelper(input1Dims, m_Prototext);
80 m_Prototext.append(R"(
81 }
82 tensor_content: ")");
83 octalHelper(input1Content, m_Prototext);
84 m_Prototext.append(R"("
85 }
86 }
87 }
88}
89node {
90 name: "output"
91 op: "Gather"
92 input: "input0"
93 input: "input1"
94 attr {
95 key: "Tindices"
96 value {
97 type: DT_INT32
98 }
99 }
100 attr {
101 key: "Tparams"
102 value {
103 type: DT_FLOAT
104 }
105 }
106}
107 )");
108 Setup({ { "input0", inputShape0 },
109 { "input1", inputShape1 } },
110 { "output" });
111
112 }
113};
114
115
116struct GatherFixture1DParams1DIndices : public GatherFixture
117{
118 GatherFixture1DParams1DIndices() : GatherFixture(
119 { 4, 1, 1, 1 },
120 { 4, 0, 0, 0 },
121 { 0, 2, 1, 3 },
122 { 4 },
123 { 4 }) {}
124};
125
126struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
127{
128 GatherFixture1DParamsMultiDimIndices() : GatherFixture(
129 { 4, 1, 1 },
130 { 2, 2, 1, 1 },
131 { 0, 1, 1, 3 },
132 { 4 },
133 { 2, 2 }) {}
134};
135
136struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
137{
138 GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
139 { 5, 2, 1 },
140 { 2, 1, 4 },
141 { 1, 3, 0, 2 },
142 { 5, 2 },
143 { 2, 2 }) {}
144};
145
146BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
147{
148 RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
149
150 { { "output", { 1, 3, 2, 4 } } });
151}
152
153BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
154{
155 RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
156
157 { { "output", { 1, 2, 2, 4 } } });
158}
159
160BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
161{
162 RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
163
164 { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
165}
166
167BOOST_AUTO_TEST_SUITE_END()