blob: e9bcd278cf55c2dc366b0129bdbd78f99d4f7124 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_Reshape")
10{
telsoa01c577f2c2018-08-31 09:22:23 +010011struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
13 ReshapeMainFixture(const std::string& dataType)
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input"
25 type {
26 tensor_type {
27 elem_type: )" + dataType + R"(
28 shape {
29 dim {
30 dim_value: 4
31 }
32 }
33 }
34 }
35 }
36 input {
37 name: "Shape"
38 type {
39 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000040 elem_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010041 shape {
42 dim {
43 dim_value: 2
44 }
45 }
46 }
47 }
48 }
49 node {
50 input: "Input"
51 input: "Shape"
52 output: "Output"
53 name: "reshape"
54 op_type: "Reshape"
55
56 }
57 initializer {
58 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +000059 data_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010060 int64_data: 2
61 int64_data: 2
62 name: "Shape"
63 }
64 output {
65 name: "Output"
66 type {
67 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000068 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010069 shape {
70 dim {
71 dim_value: 2
72 }
73 dim {
74 dim_value: 2
75 }
76 }
77 }
78 }
79 }
80 }
81 opset_import {
82 version: 7
83 })";
84 }
85};
86
Ryan OSheaed27ee72020-04-22 16:37:29 +010087struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
88{
89 ReshapeRank4Fixture(const std::string& dataType)
90 {
91 m_Prototext = R"(
92 ir_version: 3
93 producer_name: "CNTK"
94 producer_version: "2.5.1"
95 domain: "ai.cntk"
96 model_version: 1
97 graph {
98 name: "CNTKGraph"
99 input {
100 name: "Input"
101 type {
102 tensor_type {
103 elem_type: )" + dataType + R"(
104 shape {
105 dim {
106 dim_value: 2
107 }
108 dim {
109 dim_value: 2
110 }
111 dim {
112 dim_value: 3
113 }
114 dim {
115 dim_value: 3
116 }
117 }
118 }
119 }
120 }
121 input {
122 name: "Shape"
123 type {
124 tensor_type {
125 elem_type: 7
126 shape {
127 dim {
128 dim_value: 2
129 }
130 }
131 }
132 }
133 }
134 node {
135 input: "Input"
136 input: "Shape"
137 output: "Output"
138 name: "reshape"
139 op_type: "Reshape"
140
141 }
142 initializer {
143 dims: 2
144 data_type: 7
145 int64_data: 2
146 int64_data: 2
147 name: "Shape"
148 }
149 output {
150 name: "Output"
151 type {
152 tensor_type {
153 elem_type: 1
154 shape {
155 dim {
156 dim_value: 6
157 }
158 dim {
159 dim_value: 6
160 }
161 }
162 }
163 }
164 }
165 }
166 opset_import {
167 version: 7
168 })";
169 }
170};
171
telsoa01c577f2c2018-08-31 09:22:23 +0100172struct ReshapeValidFixture : ReshapeMainFixture
173{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000174 ReshapeValidFixture() : ReshapeMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100175 Setup();
176 }
177};
178
Ryan OSheaed27ee72020-04-22 16:37:29 +0100179struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
180{
181 ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
182 Setup();
183 }
184};
185
telsoa01c577f2c2018-08-31 09:22:23 +0100186struct ReshapeInvalidFixture : ReshapeMainFixture
187{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000188 ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100189};
190
Sadik Armagan1625efc2021-06-10 18:24:34 +0100191TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100192{
193 RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
194}
195
Sadik Armagan1625efc2021-06-10 18:24:34 +0100196TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100197{
198 RunTest<2>(
199 {{"Input",
200 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
201 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
203 {{"Output",
204 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
205 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
207}
208
Sadik Armagan1625efc2021-06-10 18:24:34 +0100209TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
telsoa01c577f2c2018-08-31 09:22:23 +0100210{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100211 CHECK_THROWS_AS(Setup(), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100212}
213
Sadik Armagan1625efc2021-06-10 18:24:34 +0100214}