//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <vector>

9TEST_SUITE("OnnxParser_LoadScopeDynamicTensor")
10{
11
12struct DynamicBatchTensorFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 DynamicBatchTensorFixture()
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: 1
29 shape {
30 dim {
31 dim_value: 0
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 3
38 }
39 dim {
40 dim_value: 3
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "Weight"
48 type {
49 tensor_type {
50 elem_type: 1
51 shape {
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 1
57 }
58 dim {
59 dim_value: 3
60 }
61 dim {
62 dim_value: 3
63 }
64 }
65 }
66 }
67 }
68 initializer {
69 dims: 1
70 dims: 1
71 dims: 3
72 dims: 3
73 data_type: 1
74 float_data: 2
75 float_data: 1
76 float_data: 0
77 float_data: 6
78 float_data: 2
79 float_data: 1
80 float_data: 4
81 float_data: 1
82 float_data: 2
83 name: "Weight"
84 }
85 node {
86 input: "Input"
87 input: "Weight"
88 output: "Output"
89 name: "Convolution"
90 op_type: "Conv"
91 attribute {
92 name: "kernel_shape"
93 ints: 3
94 ints: 3
95 type: INTS
96 }
97 attribute {
98 name: "strides"
99 ints: 1
100 ints: 1
101 type: INTS
102 }
103 attribute {
104 name: "auto_pad"
105 s: "VALID"
106 type: STRING
107 }
108 attribute {
109 name: "group"
110 i: 1
111 type: INT
112 }
113 attribute {
114 name: "dilations"
115 ints: 1
116 ints: 1
117 type: INTS
118 }
119 doc_string: ""
120 domain: ""
121 }
122 output {
123 name: "Output"
124 type {
125 tensor_type {
126 elem_type: 1
127 shape {
128 dim {
129 dim_value: 0
130 }
131 dim {
132 dim_value: 1
133 }
134 dim {
135 dim_value: 1
136 }
137 dim {
138 dim_value: 1
139 }
140 }
141 }
142 }
143 }
144 }
145 opset_import {
146 version: 7
147 })";
148 }
149};
150
151TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "DynamicBatchTensorTest")
152{
153 Setup({{"Input", armnn::TensorShape({1, 1, 3, 3})}});
154 RunTest<4>({{"Input", {1.0, 2.0, 3.0,
155 4.0, 5.0, 6.0,
156 7.0, 8.0, 9.0}}},
157 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
158 4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
159 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
160}
161
162TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "TensorShapeNotSpecifiedTest")
163{
164 CHECK_THROWS_AS(Setup(), armnn::ParseException);
165}
166
167TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectInputNameTest")
168{
169 CHECK_THROWS_AS(Setup({{"Incorrect", armnn::TensorShape({1, 1, 3, 3})}}), armnn::ParseException);
170}
171
172TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectBatchTensorTest")
173{
174 Setup({{"Input", armnn::TensorShape({2, 1, 3, 3}) }});
175 CHECK_THROWS_AS(RunTest<4>({{"Input", { 1.0, 2.0, 3.0,
176 4.0, 5.0, 6.0,
177 7.0, 8.0, 9.0 }}},
178 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
179 4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
180 7.0 * 4 + 8.0 * 1 + 9.0 * 2 }}}), armnn::Exception);
181
182}
183
184}