blob: 4db69996eb5209953b1189d21f713955f2e40f7f [file] [log] [blame]
Keith Davis4cd29a02019-09-09 14:49:20 +01001//
Finn Williamsb49ed182021-06-29 15:50:08 +01002// Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
Keith Davis4cd29a02019-09-09 14:49:20 +01003// SPDX-License-Identifier: MIT
4//
5
Keith Davis4cd29a02019-09-09 14:49:20 +01006#include "ParserFlatbuffersFixture.hpp"
Keith Davis4cd29a02019-09-09 14:49:20 +01007
Sadik Armagan1625efc2021-06-10 18:24:34 +01008TEST_SUITE("TensorflowLiteParser_Transpose")
9{
Keith Davis4cd29a02019-09-09 14:49:20 +010010struct TransposeFixture : public ParserFlatbuffersFixture
11{
12 explicit TransposeFixture(const std::string & inputShape,
Kevin May85d92602019-09-27 17:21:06 +010013 const std::string & permuteData,
Keith Davis4cd29a02019-09-09 14:49:20 +010014 const std::string & outputShape)
15 {
16 m_JsonString = R"(
17 {
18 "version": 3,
19 "operator_codes": [
20 {
21 "builtin_code": "TRANSPOSE",
22 "version": 1
23 }
24 ],
25 "subgraphs": [
26 {
27 "tensors": [
28 {
29 "shape": )" + inputShape + R"(,
30 "type": "FLOAT32",
Kevin May85d92602019-09-27 17:21:06 +010031 "buffer": 0,
32 "name": "inputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010033 "quantization": {
34 "min": [
35 0.0
36 ],
37 "max": [
38 255.0
39 ],
40 "details_type": 0,
41 "quantized_dimension": 0
42 },
43 "is_variable": false
44 },
45 {
46 "shape": )" + outputShape + R"(,
47 "type": "FLOAT32",
Keith Davis4cd29a02019-09-09 14:49:20 +010048 "buffer": 1,
Kevin May85d92602019-09-27 17:21:06 +010049 "name": "outputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010050 "quantization": {
51 "details_type": 0,
52 "quantized_dimension": 0
53 },
54 "is_variable": false
Kevin May85d92602019-09-27 17:21:06 +010055 })";
josh minorba424d22019-11-13 10:55:17 -060056 m_JsonString += R"(,
57 {
58 "shape": [
59 3
60 ],
61 "type": "INT32",
62 "buffer": 2,
63 "name": "permuteTensor",
64 "quantization": {
65 "details_type": 0,
66 "quantized_dimension": 0
67 },
68 "is_variable": false
69 })";
Kevin May85d92602019-09-27 17:21:06 +010070 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010071 "inputs": [
72 0
73 ],
74 "outputs": [
75 1
76 ],
77 "operators": [
78 {
79 "opcode_index": 0,
80 "inputs": [
Kevin May85d92602019-09-27 17:21:06 +010081 0)";
josh minorba424d22019-11-13 10:55:17 -060082 m_JsonString += R"(,2)";
Kevin May85d92602019-09-27 17:21:06 +010083 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010084 "outputs": [
85 1
86 ],
87 "builtin_options_type": "TransposeOptions",
88 "builtin_options": {
89 },
90 "custom_options_format": "FLEXBUFFERS"
91 }
92 ]
93 }
94 ],
95 "description": "TOCO Converted.",
96 "buffers": [
97 { },
Kevin May85d92602019-09-27 17:21:06 +010098 { })";
99 if (!permuteData.empty())
100 {
101 m_JsonString += R"(,{"data": )" + permuteData + R"( })";
102 }
103 m_JsonString += R"(
Keith Davis4cd29a02019-09-09 14:49:20 +0100104 ]
105 }
106 )";
107 Setup();
108 }
109};
110
// Note that this follows the TensorFlow permutation-vector convention (output dimension i takes input
// dimension perm[i]), as opposed to the Arm NN implementation.
Kevin May85d92602019-09-27 17:21:06 +0100112struct TransposeFixtureWithPermuteData : TransposeFixture
Keith Davis4cd29a02019-09-09 14:49:20 +0100113{
Kevin May85d92602019-09-27 17:21:06 +0100114 TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
115 "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
116 "[ 2, 3, 2 ]") {}
Keith Davis4cd29a02019-09-09 14:49:20 +0100117};
118
Sadik Armagan1625efc2021-06-10 18:24:34 +0100119TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData")
Keith Davis4cd29a02019-09-09 14:49:20 +0100120{
121 RunTest<3, armnn::DataType::Float32>(
122 0,
josh minorba424d22019-11-13 10:55:17 -0600123 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
124 {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
Keith Davis4cd29a02019-09-09 14:49:20 +0100125
Sadik Armagan1625efc2021-06-10 18:24:34 +0100126 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
Kevin May85d92602019-09-27 17:21:06 +0100127 == armnn::TensorShape({2,3,2})));
128}
129
// When no permutation is given, TensorFlow's TRANSPOSE defaults to the reversed permutation vector
// [n-1, ..., 0], where n is the rank of the input tensor.
// This fixture supplies that default vector, [2, 1, 0], explicitly as constant data, so for a
// [2, 2, 3] input the expected output shape is [3, 2, 2].
Kevin May85d92602019-09-27 17:21:06 +0100133struct TransposeFixtureWithoutPermuteData : TransposeFixture
134{
135 TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
josh minorba424d22019-11-13 10:55:17 -0600136 "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
137 "[ 3, 2, 2 ]") {}
Kevin May85d92602019-09-27 17:21:06 +0100138};
139
Sadik Armagan1625efc2021-06-10 18:24:34 +0100140TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims")
Kevin May85d92602019-09-27 17:21:06 +0100141{
142 RunTest<3, armnn::DataType::Float32>(
143 0,
josh minorba424d22019-11-13 10:55:17 -0600144 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
145 {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
Kevin May85d92602019-09-27 17:21:06 +0100146
Sadik Armagan1625efc2021-06-10 18:24:34 +0100147 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
josh minorba424d22019-11-13 10:55:17 -0600148 == armnn::TensorShape({3,2,2})));
Keith Davis4cd29a02019-09-09 14:49:20 +0100149}
150
Sadik Armagan1625efc2021-06-10 18:24:34 +0100151}