//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12struct TransposeFixture : public ParserFlatbuffersFixture
13{
14 explicit TransposeFixture(const std::string & inputShape,
Kevin May85d92602019-09-27 17:21:06 +010015 const std::string & permuteData,
Keith Davis4cd29a02019-09-09 14:49:20 +010016 const std::string & outputShape)
17 {
18 m_JsonString = R"(
19 {
20 "version": 3,
21 "operator_codes": [
22 {
23 "builtin_code": "TRANSPOSE",
24 "version": 1
25 }
26 ],
27 "subgraphs": [
28 {
29 "tensors": [
30 {
31 "shape": )" + inputShape + R"(,
32 "type": "FLOAT32",
Kevin May85d92602019-09-27 17:21:06 +010033 "buffer": 0,
34 "name": "inputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010035 "quantization": {
36 "min": [
37 0.0
38 ],
39 "max": [
40 255.0
41 ],
42 "details_type": 0,
43 "quantized_dimension": 0
44 },
45 "is_variable": false
46 },
47 {
48 "shape": )" + outputShape + R"(,
49 "type": "FLOAT32",
Keith Davis4cd29a02019-09-09 14:49:20 +010050 "buffer": 1,
Kevin May85d92602019-09-27 17:21:06 +010051 "name": "outputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010052 "quantization": {
53 "details_type": 0,
54 "quantized_dimension": 0
55 },
56 "is_variable": false
Kevin May85d92602019-09-27 17:21:06 +010057 })";
58 if (!permuteData.empty())
59 {
60 m_JsonString += R"(,
61 {
62 "shape": [
63 3
64 ],
65 "type": "INT32",
66 "buffer": 2,
67 "name": "permuteTensor",
68 "quantization": {
69 "details_type": 0,
70 "quantized_dimension": 0
71 },
72 "is_variable": false
73 })";
74 }
75
76 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010077 "inputs": [
78 0
79 ],
80 "outputs": [
81 1
82 ],
83 "operators": [
84 {
85 "opcode_index": 0,
86 "inputs": [
Kevin May85d92602019-09-27 17:21:06 +010087 0)";
88 if (!permuteData.empty())
89 {
90 m_JsonString += R"(,2)";
91 }
92 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010093 "outputs": [
94 1
95 ],
96 "builtin_options_type": "TransposeOptions",
97 "builtin_options": {
98 },
99 "custom_options_format": "FLEXBUFFERS"
100 }
101 ]
102 }
103 ],
104 "description": "TOCO Converted.",
105 "buffers": [
106 { },
Kevin May85d92602019-09-27 17:21:06 +0100107 { })";
108 if (!permuteData.empty())
109 {
110 m_JsonString += R"(,{"data": )" + permuteData + R"( })";
111 }
112 m_JsonString += R"(
Keith Davis4cd29a02019-09-09 14:49:20 +0100113 ]
114 }
115 )";
116 Setup();
117 }
118};
119
Kevin May85d92602019-09-27 17:21:06 +0100120struct TransposeFixtureWithPermuteData : TransposeFixture
Keith Davis4cd29a02019-09-09 14:49:20 +0100121{
Kevin May85d92602019-09-27 17:21:06 +0100122 TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
123 "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
124 "[ 2, 3, 2 ]") {}
Keith Davis4cd29a02019-09-09 14:49:20 +0100125};
126
Kevin May85d92602019-09-27 17:21:06 +0100127BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
Keith Davis4cd29a02019-09-09 14:49:20 +0100128{
129 RunTest<3, armnn::DataType::Float32>(
130 0,
Kevin May85d92602019-09-27 17:21:06 +0100131 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
132 {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
Keith Davis4cd29a02019-09-09 14:49:20 +0100133
Kevin May85d92602019-09-27 17:21:06 +0100134 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
135 == armnn::TensorShape({2,3,2})));
136}
137
138struct TransposeFixtureWithoutPermuteData : TransposeFixture
139{
140 TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
141 "",
142 "[ 2, 3, 2 ]") {}
143};
144
145BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
146{
147 RunTest<3, armnn::DataType::Float32>(
148 0,
149 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
150 {{"outputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
151
152 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
Keith Davis4cd29a02019-09-09 14:49:20 +0100153 == armnn::TensorShape({2,3,2})));
154}
155
156BOOST_AUTO_TEST_SUITE_END()