blob: b2f953e75d53cafdf749616410b3e86ef6cf5f42 [file] [log] [blame]
Keith Davis4cd29a02019-09-09 14:49:20 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12struct TransposeFixture : public ParserFlatbuffersFixture
13{
14 explicit TransposeFixture(const std::string & inputShape,
Kevin May85d92602019-09-27 17:21:06 +010015 const std::string & permuteData,
Keith Davis4cd29a02019-09-09 14:49:20 +010016 const std::string & outputShape)
17 {
18 m_JsonString = R"(
19 {
20 "version": 3,
21 "operator_codes": [
22 {
23 "builtin_code": "TRANSPOSE",
24 "version": 1
25 }
26 ],
27 "subgraphs": [
28 {
29 "tensors": [
30 {
31 "shape": )" + inputShape + R"(,
32 "type": "FLOAT32",
Kevin May85d92602019-09-27 17:21:06 +010033 "buffer": 0,
34 "name": "inputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010035 "quantization": {
36 "min": [
37 0.0
38 ],
39 "max": [
40 255.0
41 ],
42 "details_type": 0,
43 "quantized_dimension": 0
44 },
45 "is_variable": false
46 },
47 {
48 "shape": )" + outputShape + R"(,
49 "type": "FLOAT32",
Keith Davis4cd29a02019-09-09 14:49:20 +010050 "buffer": 1,
Kevin May85d92602019-09-27 17:21:06 +010051 "name": "outputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010052 "quantization": {
53 "details_type": 0,
54 "quantized_dimension": 0
55 },
56 "is_variable": false
Kevin May85d92602019-09-27 17:21:06 +010057 })";
josh minorba424d22019-11-13 10:55:17 -060058 m_JsonString += R"(,
59 {
60 "shape": [
61 3
62 ],
63 "type": "INT32",
64 "buffer": 2,
65 "name": "permuteTensor",
66 "quantization": {
67 "details_type": 0,
68 "quantized_dimension": 0
69 },
70 "is_variable": false
71 })";
Kevin May85d92602019-09-27 17:21:06 +010072 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010073 "inputs": [
74 0
75 ],
76 "outputs": [
77 1
78 ],
79 "operators": [
80 {
81 "opcode_index": 0,
82 "inputs": [
Kevin May85d92602019-09-27 17:21:06 +010083 0)";
josh minorba424d22019-11-13 10:55:17 -060084 m_JsonString += R"(,2)";
Kevin May85d92602019-09-27 17:21:06 +010085 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010086 "outputs": [
87 1
88 ],
89 "builtin_options_type": "TransposeOptions",
90 "builtin_options": {
91 },
92 "custom_options_format": "FLEXBUFFERS"
93 }
94 ]
95 }
96 ],
97 "description": "TOCO Converted.",
98 "buffers": [
99 { },
Kevin May85d92602019-09-27 17:21:06 +0100100 { })";
101 if (!permuteData.empty())
102 {
103 m_JsonString += R"(,{"data": )" + permuteData + R"( })";
104 }
105 m_JsonString += R"(
Keith Davis4cd29a02019-09-09 14:49:20 +0100106 ]
107 }
108 )";
109 Setup();
110 }
111};
112
josh minorba424d22019-11-13 10:55:17 -0600113// Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation.
Kevin May85d92602019-09-27 17:21:06 +0100114struct TransposeFixtureWithPermuteData : TransposeFixture
Keith Davis4cd29a02019-09-09 14:49:20 +0100115{
Kevin May85d92602019-09-27 17:21:06 +0100116 TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
117 "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
118 "[ 2, 3, 2 ]") {}
Keith Davis4cd29a02019-09-09 14:49:20 +0100119};
120
Kevin May85d92602019-09-27 17:21:06 +0100121BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
Keith Davis4cd29a02019-09-09 14:49:20 +0100122{
123 RunTest<3, armnn::DataType::Float32>(
124 0,
josh minorba424d22019-11-13 10:55:17 -0600125 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
126 {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
Keith Davis4cd29a02019-09-09 14:49:20 +0100127
Kevin May85d92602019-09-27 17:21:06 +0100128 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
129 == armnn::TensorShape({2,3,2})));
130}
131
josh minorba424d22019-11-13 10:55:17 -0600132// Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0],
133// where n is the number of dimensions of the input tensor
134// In this case we should get output shape 3,2,2 given default permutation vector 2,1,0
Kevin May85d92602019-09-27 17:21:06 +0100135struct TransposeFixtureWithoutPermuteData : TransposeFixture
136{
137 TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
josh minorba424d22019-11-13 10:55:17 -0600138 "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
139 "[ 3, 2, 2 ]") {}
Kevin May85d92602019-09-27 17:21:06 +0100140};
141
142BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
143{
144 RunTest<3, armnn::DataType::Float32>(
145 0,
josh minorba424d22019-11-13 10:55:17 -0600146 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
147 {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
Kevin May85d92602019-09-27 17:21:06 +0100148
149 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
josh minorba424d22019-11-13 10:55:17 -0600150 == armnn::TensorShape({3,2,2})));
Keith Davis4cd29a02019-09-09 14:49:20 +0100151}
152
153BOOST_AUTO_TEST_SUITE_END()