blob: ba739cf7b5b2d4140dde27a40d6b085defd2e4b9 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 ReshapeMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 4
32 }
33 }
34 }
35 }
36 }
37 input {
38 name: "Shape"
39 type {
40 tensor_type {
41 elem_type: INT64
42 shape {
43 dim {
44 dim_value: 2
45 }
46 }
47 }
48 }
49 }
50 node {
51 input: "Input"
52 input: "Shape"
53 output: "Output"
54 name: "reshape"
55 op_type: "Reshape"
56
57 }
58 initializer {
59 dims: 2
60 data_type: INT64
61 int64_data: 2
62 int64_data: 2
63 name: "Shape"
64 }
65 output {
66 name: "Output"
67 type {
68 tensor_type {
69 elem_type: FLOAT
70 shape {
71 dim {
72 dim_value: 2
73 }
74 dim {
75 dim_value: 2
76 }
77 }
78 }
79 }
80 }
81 }
82 opset_import {
83 version: 7
84 })";
85 }
86};
87
88struct ReshapeValidFixture : ReshapeMainFixture
89{
90 ReshapeValidFixture() : ReshapeMainFixture("FLOAT") {
91 Setup();
92 }
93};
94
95struct ReshapeInvalidFixture : ReshapeMainFixture
96{
97 ReshapeInvalidFixture() : ReshapeMainFixture("FLOAT16") { }
98};
99
100BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
101{
102 RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
103}
104
105BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
106{
107 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
108}
109
110BOOST_AUTO_TEST_SUITE_END()