blob: 5429e567ef539f057a1054d1e95c32dc8b88d8dc [file] [log] [blame]
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Keith Davis4cd29a02019-09-09 14:49:20 +01006#include "ParserFlatbuffersFixture.hpp"
7#include "../TfLiteParser.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("TensorflowLiteParser_Transpose")
10{
Keith Davis4cd29a02019-09-09 14:49:20 +010011struct TransposeFixture : public ParserFlatbuffersFixture
12{
13 explicit TransposeFixture(const std::string & inputShape,
Kevin May85d92602019-09-27 17:21:06 +010014 const std::string & permuteData,
Keith Davis4cd29a02019-09-09 14:49:20 +010015 const std::string & outputShape)
16 {
17 m_JsonString = R"(
18 {
19 "version": 3,
20 "operator_codes": [
21 {
22 "builtin_code": "TRANSPOSE",
23 "version": 1
24 }
25 ],
26 "subgraphs": [
27 {
28 "tensors": [
29 {
30 "shape": )" + inputShape + R"(,
31 "type": "FLOAT32",
Kevin May85d92602019-09-27 17:21:06 +010032 "buffer": 0,
33 "name": "inputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010034 "quantization": {
35 "min": [
36 0.0
37 ],
38 "max": [
39 255.0
40 ],
41 "details_type": 0,
42 "quantized_dimension": 0
43 },
44 "is_variable": false
45 },
46 {
47 "shape": )" + outputShape + R"(,
48 "type": "FLOAT32",
Keith Davis4cd29a02019-09-09 14:49:20 +010049 "buffer": 1,
Kevin May85d92602019-09-27 17:21:06 +010050 "name": "outputTensor",
Keith Davis4cd29a02019-09-09 14:49:20 +010051 "quantization": {
52 "details_type": 0,
53 "quantized_dimension": 0
54 },
55 "is_variable": false
Kevin May85d92602019-09-27 17:21:06 +010056 })";
josh minorba424d22019-11-13 10:55:17 -060057 m_JsonString += R"(,
58 {
59 "shape": [
60 3
61 ],
62 "type": "INT32",
63 "buffer": 2,
64 "name": "permuteTensor",
65 "quantization": {
66 "details_type": 0,
67 "quantized_dimension": 0
68 },
69 "is_variable": false
70 })";
Kevin May85d92602019-09-27 17:21:06 +010071 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010072 "inputs": [
73 0
74 ],
75 "outputs": [
76 1
77 ],
78 "operators": [
79 {
80 "opcode_index": 0,
81 "inputs": [
Kevin May85d92602019-09-27 17:21:06 +010082 0)";
josh minorba424d22019-11-13 10:55:17 -060083 m_JsonString += R"(,2)";
Kevin May85d92602019-09-27 17:21:06 +010084 m_JsonString += R"(],
Keith Davis4cd29a02019-09-09 14:49:20 +010085 "outputs": [
86 1
87 ],
88 "builtin_options_type": "TransposeOptions",
89 "builtin_options": {
90 },
91 "custom_options_format": "FLEXBUFFERS"
92 }
93 ]
94 }
95 ],
96 "description": "TOCO Converted.",
97 "buffers": [
98 { },
Kevin May85d92602019-09-27 17:21:06 +010099 { })";
100 if (!permuteData.empty())
101 {
102 m_JsonString += R"(,{"data": )" + permuteData + R"( })";
103 }
104 m_JsonString += R"(
Keith Davis4cd29a02019-09-09 14:49:20 +0100105 ]
106 }
107 )";
108 Setup();
109 }
110};
111
josh minorba424d22019-11-13 10:55:17 -0600112// Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation.
Kevin May85d92602019-09-27 17:21:06 +0100113struct TransposeFixtureWithPermuteData : TransposeFixture
Keith Davis4cd29a02019-09-09 14:49:20 +0100114{
Kevin May85d92602019-09-27 17:21:06 +0100115 TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
116 "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
117 "[ 2, 3, 2 ]") {}
Keith Davis4cd29a02019-09-09 14:49:20 +0100118};
119
Sadik Armagan1625efc2021-06-10 18:24:34 +0100120TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData")
Keith Davis4cd29a02019-09-09 14:49:20 +0100121{
122 RunTest<3, armnn::DataType::Float32>(
123 0,
josh minorba424d22019-11-13 10:55:17 -0600124 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
125 {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
Keith Davis4cd29a02019-09-09 14:49:20 +0100126
Sadik Armagan1625efc2021-06-10 18:24:34 +0100127 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
Kevin May85d92602019-09-27 17:21:06 +0100128 == armnn::TensorShape({2,3,2})));
129}
130
josh minorba424d22019-11-13 10:55:17 -0600131// Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0],
132// where n is the number of dimensions of the input tensor
133// In this case we should get output shape 3,2,2 given default permutation vector 2,1,0
Kevin May85d92602019-09-27 17:21:06 +0100134struct TransposeFixtureWithoutPermuteData : TransposeFixture
135{
136 TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
josh minorba424d22019-11-13 10:55:17 -0600137 "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
138 "[ 3, 2, 2 ]") {}
Kevin May85d92602019-09-27 17:21:06 +0100139};
140
Sadik Armagan1625efc2021-06-10 18:24:34 +0100141TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims")
Kevin May85d92602019-09-27 17:21:06 +0100142{
143 RunTest<3, armnn::DataType::Float32>(
144 0,
josh minorba424d22019-11-13 10:55:17 -0600145 {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
146 {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
Kevin May85d92602019-09-27 17:21:06 +0100147
Sadik Armagan1625efc2021-06-10 18:24:34 +0100148 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
josh minorba424d22019-11-13 10:55:17 -0600149 == armnn::TensorShape({3,2,2})));
Keith Davis4cd29a02019-09-09 14:49:20 +0100150}
151
Sadik Armagan1625efc2021-06-10 18:24:34 +0100152}