blob: a6c20fd63e51dfdd53ec129578556cd1a6da4928 [file] [log] [blame]
FrancisMurtagh94412af2019-01-24 10:53:39 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnTfParser/ITfParser.hpp"
7
8#include "ParserPrototxtFixture.hpp"
9#include <PrototxtConversions.hpp>
10
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
Georgios Pinitas5e90aab2020-02-14 14:46:51 +000015namespace {
// Helper for setting the dimensions in prototxt.
// Appends one `dim { size: <value> }` entry to `text` for every element of
// `dims`, in order. `text` is extended in place; an empty `dims` leaves it
// unchanged.
void dimsHelper(const std::vector<int>& dims, std::string& text)
{
    // Range-for avoids the non-standard `u_int` index type and the
    // index-vs-size() comparison altogether.
    for (const int dim : dims)
    {
        text.append(R"(dim {
 size: )");
        text.append(std::to_string(dim));
        text.append(R"(
 })");
    }
}
26
27// helper for converting from integer to octal representation
28void octalHelper(const std::vector<int>& indicesContent, std::string& text){
Georgios Pinitas5e90aab2020-02-14 14:46:51 +000029 for(unsigned int i = 0; i < indicesContent.size(); ++i) {
FrancisMurtagh94412af2019-01-24 10:53:39 +000030 text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
31 }
32}
Georgios Pinitas5e90aab2020-02-14 14:46:51 +000033} // namespace
FrancisMurtagh94412af2019-01-24 10:53:39 +000034
35struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
36{
37 GatherFixture(const armnn::TensorShape& inputShape0,
38 const armnn::TensorShape& inputShape1,
39 const std::vector<int>& input1Content,
40 const std::vector<int>& input0Dims,
41 const std::vector<int>& input1Dims)
42 {
43 m_Prototext = R"(
44node {
45 name: "input0"
46 op: "Placeholder"
47 attr {
48 key: "dtype"
49 value {
50 type: DT_FLOAT
51 }
52 }
53 attr {
54 key: "shape"
55 value {
56 shape {
57)";
58 dimsHelper(input0Dims, m_Prototext);
59 m_Prototext.append(R"(
60 }
61 }
62 }
63}
64node {
65 name: "input1"
66 op: "Const"
67 attr {
68 key: "dtype"
69 value {
70 type: DT_INT32
71 }
72 }
73 attr {
74 key: "value"
75 value {
76 tensor {
77 dtype: DT_INT32
78 tensor_shape {
79)");
80 dimsHelper(input1Dims, m_Prototext);
81 m_Prototext.append(R"(
82 }
83 tensor_content: ")");
84 octalHelper(input1Content, m_Prototext);
85 m_Prototext.append(R"("
86 }
87 }
88 }
89}
90node {
91 name: "output"
92 op: "Gather"
93 input: "input0"
94 input: "input1"
95 attr {
96 key: "Tindices"
97 value {
98 type: DT_INT32
99 }
100 }
101 attr {
102 key: "Tparams"
103 value {
104 type: DT_FLOAT
105 }
106 }
107}
108 )");
109 Setup({ { "input0", inputShape0 },
110 { "input1", inputShape1 } },
111 { "output" });
112
113 }
114};
115
116
117struct GatherFixture1DParams1DIndices : public GatherFixture
118{
119 GatherFixture1DParams1DIndices() : GatherFixture(
120 { 4, 1, 1, 1 },
121 { 4, 0, 0, 0 },
122 { 0, 2, 1, 3 },
123 { 4 },
124 { 4 }) {}
125};
126
127struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
128{
129 GatherFixture1DParamsMultiDimIndices() : GatherFixture(
130 { 4, 1, 1 },
131 { 2, 2, 1, 1 },
132 { 0, 1, 1, 3 },
133 { 4 },
134 { 2, 2 }) {}
135};
136
137struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
138{
139 GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
140 { 5, 2, 1 },
141 { 2, 1, 4 },
142 { 1, 3, 0, 2 },
143 { 5, 2 },
144 { 2, 2 }) {}
145};
146
147BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
148{
149 RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
150
151 { { "output", { 1, 3, 2, 4 } } });
152}
153
154BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
155{
156 RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
157
158 { { "output", { 1, 2, 2, 4 } } });
159}
160
161BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
162{
163 RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
164
165 { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
166}
167
168BOOST_AUTO_TEST_SUITE_END()