Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "armnnOnnxParser/IOnnxParser.hpp" |
| 7 | #include "ParserPrototxtFixture.hpp" |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 8 | #include "OnnxParserTestUtils.hpp" |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 9 | |
| 10 | TEST_SUITE("OnnxParser_Shape") |
| 11 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 12 | |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 13 | struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> |
| 14 | { |
| 15 | ShapeMainFixture(const std::string& inputType, |
| 16 | const std::string& outputType, |
| 17 | const std::string& outputDim, |
| 18 | const std::vector<int>& inputShape) |
| 19 | { |
| 20 | m_Prototext = R"( |
| 21 | ir_version: 8 |
| 22 | producer_name: "onnx-example" |
| 23 | graph { |
| 24 | node { |
| 25 | input: "Input" |
| 26 | output: "Output" |
| 27 | op_type: "Shape" |
| 28 | } |
| 29 | name: "shape-model" |
| 30 | input { |
| 31 | name: "Input" |
| 32 | type { |
| 33 | tensor_type { |
| 34 | elem_type: )" + inputType + R"( |
| 35 | shape { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 36 | )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 37 | } |
| 38 | } |
| 39 | } |
| 40 | } |
| 41 | output { |
| 42 | name: "Output" |
| 43 | type { |
| 44 | tensor_type { |
| 45 | elem_type: )" + outputType + R"( |
| 46 | shape { |
| 47 | dim { |
| 48 | dim_value: )" + outputDim + R"( |
| 49 | } |
| 50 | } |
| 51 | } |
| 52 | } |
| 53 | } |
| 54 | } |
| 55 | opset_import { |
| 56 | version: 10 |
| 57 | })"; |
| 58 | } |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 59 | }; |
| 60 | |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 61 | struct ShapeFloatFixture : ShapeMainFixture |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 62 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 63 | ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) |
| 64 | { |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 65 | Setup(); |
| 66 | } |
| 67 | }; |
| 68 | |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 69 | struct ShapeIntFixture : ShapeMainFixture |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 70 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 71 | ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) |
| 72 | { |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 73 | Setup(); |
| 74 | } |
| 75 | }; |
| 76 | |
| 77 | struct Shape3DFixture : ShapeMainFixture |
| 78 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 79 | Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) |
| 80 | { |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 81 | Setup(); |
| 82 | } |
| 83 | }; |
| 84 | |
| 85 | struct Shape2DFixture : ShapeMainFixture |
| 86 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 87 | Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) |
| 88 | { |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 89 | Setup(); |
| 90 | } |
| 91 | }; |
| 92 | |
| 93 | struct Shape1DFixture : ShapeMainFixture |
| 94 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 95 | Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) |
| 96 | { |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 97 | Setup(); |
| 98 | } |
| 99 | }; |
| 100 | |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 101 | TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest") |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 102 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 103 | RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 104 | 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, |
| 105 | 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}}); |
| 106 | } |
| 107 | |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 108 | TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest") |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 109 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 110 | RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4, |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 111 | 4, 3, 2, 1, 0, |
| 112 | 0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}}); |
| 113 | } |
| 114 | |
| 115 | TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest") |
| 116 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 117 | RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 118 | 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, |
| 119 | 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}}); |
| 120 | } |
| 121 | |
| 122 | TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest") |
| 123 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 124 | RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}}); |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 125 | } |
| 126 | |
| 127 | TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest") |
| 128 | { |
Narumol Prangnawarat | 452274c | 2021-09-23 16:12:19 +0100 | [diff] [blame] | 129 | RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}}); |
Narumol Prangnawarat | cdc495e | 2021-09-16 18:13:39 +0100 | [diff] [blame] | 130 | } |
| 131 | |
| 132 | } |