blob: 4430438bb97d2c960d3f4e7b034b54e1c7813f7f [file] [log] [blame]
Keith Davis4cd29a02019-09-09 14:49:20 +01001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12struct TransposeFixture : public ParserFlatbuffersFixture
13{
14 explicit TransposeFixture(const std::string & inputShape,
15 const std::string & outputShape)
16 {
17 m_JsonString = R"(
18 {
19 "version": 3,
20 "operator_codes": [
21 {
22 "builtin_code": "TRANSPOSE",
23 "version": 1
24 }
25 ],
26 "subgraphs": [
27 {
28 "tensors": [
29 {
30 "shape": )" + inputShape + R"(,
31 "type": "FLOAT32",
32 "buffer": 3,
33 "name": "Placeholder",
34 "quantization": {
35 "min": [
36 0.0
37 ],
38 "max": [
39 255.0
40 ],
41 "details_type": 0,
42 "quantized_dimension": 0
43 },
44 "is_variable": false
45 },
46 {
47 "shape": )" + outputShape + R"(,
48 "type": "FLOAT32",
49 "buffer": 2,
50 "name": "transpose",
51 "quantization": {
52 "details_type": 0,
53 "quantized_dimension": 0
54 },
55 "is_variable": false
56 },
57 {
58 "shape": [
59 3
60 ],
61 "type": "INT32",
62 "buffer": 1,
63 "name": "transpose/perm",
64 "quantization": {
65 "details_type": 0,
66 "quantized_dimension": 0
67 },
68 "is_variable": false
69 }
70 ],
71 "inputs": [
72 0
73 ],
74 "outputs": [
75 1
76 ],
77 "operators": [
78 {
79 "opcode_index": 0,
80 "inputs": [
81 0,
82 2
83 ],
84 "outputs": [
85 1
86 ],
87 "builtin_options_type": "TransposeOptions",
88 "builtin_options": {
89 },
90 "custom_options_format": "FLEXBUFFERS"
91 }
92 ]
93 }
94 ],
95 "description": "TOCO Converted.",
96 "buffers": [
97 { },
98 { },
99 { },
100 { }
101 ]
102 }
103 )";
104 Setup();
105 }
106};
107
108struct SimpleTransposeFixture : TransposeFixture
109{
110 SimpleTransposeFixture() : TransposeFixture("[ 2, 2, 3 ]",
111 "[ 2, 3, 2 ]") {}
112};
113
114BOOST_FIXTURE_TEST_CASE(SimpleTranspose, SimpleTransposeFixture)
115{
116 RunTest<3, armnn::DataType::Float32>(
117 0,
118 {{"Placeholder", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
119
120 {{"transpose", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
121 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "transpose").second.GetShape()
122 == armnn::TensorShape({2,3,2})));
123}
124
125BOOST_AUTO_TEST_SUITE_END()