blob: e01c4b8b55f90fad7ab6324e570e3e007a22f538 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01008#include "OnnxParserTestUtils.hpp"
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01009
10TEST_SUITE("OnnxParser_Shape")
11{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010012
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010013struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14{
15 ShapeMainFixture(const std::string& inputType,
16 const std::string& outputType,
17 const std::string& outputDim,
18 const std::vector<int>& inputShape)
19 {
20 m_Prototext = R"(
21 ir_version: 8
22 producer_name: "onnx-example"
23 graph {
24 node {
25 input: "Input"
26 output: "Output"
27 op_type: "Shape"
28 }
29 name: "shape-model"
30 input {
31 name: "Input"
32 type {
33 tensor_type {
34 elem_type: )" + inputType + R"(
35 shape {
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010036 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010037 }
38 }
39 }
40 }
41 output {
42 name: "Output"
43 type {
44 tensor_type {
45 elem_type: )" + outputType + R"(
46 shape {
47 dim {
48 dim_value: )" + outputDim + R"(
49 }
50 }
51 }
52 }
53 }
54 }
55 opset_import {
56 version: 10
57 })";
58 }
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010059};
60
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010061struct ShapeFloatFixture : ShapeMainFixture
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010062{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010063 ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
64 {
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010065 Setup();
66 }
67};
68
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010069struct ShapeIntFixture : ShapeMainFixture
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010070{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010071 ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
72 {
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010073 Setup();
74 }
75};
76
77struct Shape3DFixture : ShapeMainFixture
78{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010079 Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
80 {
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010081 Setup();
82 }
83};
84
85struct Shape2DFixture : ShapeMainFixture
86{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010087 Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
88 {
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010089 Setup();
90 }
91};
92
93struct Shape1DFixture : ShapeMainFixture
94{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010095 Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
96 {
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +010097 Setup();
98 }
99};
100
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100101TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100102{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100103 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100104 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
105 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
106}
107
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100108TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100109{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100110 RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100111 4, 3, 2, 1, 0,
112 0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
113}
114
115TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
116{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100117 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100118 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
119 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
120}
121
122TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
123{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100124 RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100125}
126
127TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
128{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100129 RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100130}
131
132}